1use crate::data::{JaggedMatrix, Matrix};
2use crate::gradientbooster::GrowPolicy;
3use crate::grower::Grower;
4use crate::histogram::HistogramMatrix;
5use crate::node::{Node, SplittableNode};
6use crate::partial_dependence::tree_partial_dependence;
7use crate::sampler::SampleMethod;
8use crate::splitter::Splitter;
9use crate::utils::fast_f64_sum;
10use crate::utils::{gain, odds, weight};
11use rayon::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::collections::{BinaryHeap, HashMap, VecDeque};
14use std::fmt::{self, Display};
15
#[derive(Deserialize, Serialize)]
/// A single decision tree, stored as a flat arena of nodes.
/// Index 0 is always the root; parent nodes reference their children
/// by index into this vector (see `Node::left_child` / `right_child`).
pub struct Tree {
    pub nodes: Vec<Node>,
}
20
21impl Default for Tree {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl Tree {
28 pub fn new() -> Self {
29 Tree { nodes: Vec::new() }
30 }
31
    /// Grow this tree on pre-binned data.
    ///
    /// * `data` - binned feature matrix (bin index per row/feature).
    /// * `index` - row indices to train on (may be a subsample).
    /// * `col_index` - columns eligible for splitting.
    /// * `cuts` - bin cut points, used to build histograms.
    /// * `grad`, `hess` - per-row gradient and hessian values.
    /// * `splitter` - split-finding strategy (holds l1/l2, learning rate, etc.).
    /// * `max_leaves`, `max_depth` - growth limits.
    /// * `parallel` - whether histogram construction may use rayon.
    /// * `sample_method` - when not `None`, gradient/hessian totals are summed
    ///   over `index` only, and histogram rows are sorted.
    /// * `grow_policy` - depth-wise (FIFO queue) or loss-guide (max-heap).
    #[allow(clippy::too_many_arguments)]
    pub fn fit<T: Splitter>(
        &mut self,
        data: &Matrix<u16>,
        mut index: Vec<usize>,
        col_index: &[usize],
        cuts: &JaggedMatrix<f64>,
        grad: &[f32],
        hess: &[f32],
        splitter: &T,
        max_leaves: usize,
        max_depth: usize,
        parallel: bool,
        sample_method: &SampleMethod,
        grow_policy: &GrowPolicy,
    ) {
        // With no sampling every row participates, so the fast full-slice sum
        // applies; otherwise sum only over the sampled `index` rows and
        // request sorted histogram construction.
        let (gradient_sum, hessian_sum, sort) = match sample_method {
            SampleMethod::None => (fast_f64_sum(grad), fast_f64_sum(hess), false),
            _ => {
                let mut gs: f64 = 0.;
                let mut hs: f64 = 0.;
                for i in index.iter() {
                    let i_ = *i;
                    gs += grad[i_] as f64;
                    hs += hess[i_] as f64;
                }
                (gs as f32, hs as f32, true)
            }
        };

        let mut n_nodes = 1;
        let root_gain = gain(&splitter.get_l2(), gradient_sum, hessian_sum);
        let root_weight = weight(
            &splitter.get_l1(),
            &splitter.get_l2(),
            &splitter.get_max_delta_step(),
            gradient_sum,
            hessian_sum,
        );
        let root_hists =
            HistogramMatrix::new(data, cuts, grad, hess, &index, col_index, parallel, sort);
        // Root spans the whole index range with unbounded weight limits.
        let root_node = SplittableNode::new(
            0,
            root_hists,
            root_weight,
            root_gain,
            gradient_sum,
            hessian_sum,
            0,
            0,
            index.len(),
            f32::NEG_INFINITY,
            f32::INFINITY,
        );
        self.nodes
            .push(root_node.as_node(splitter.get_learning_rate()));
        let mut n_leaves = 1;

        // Depth-wise growth expands in FIFO order; loss-guide always expands
        // the candidate with the best gain (heap order).
        let mut growable: Box<dyn Grower> = match grow_policy {
            GrowPolicy::DepthWise => Box::<VecDeque<SplittableNode>>::default(),
            GrowPolicy::LossGuide => Box::<BinaryHeap<SplittableNode>>::default(),
        };

        growable.add_node(root_node);
        while !growable.is_empty() {
            // Stop growing once splitting anything else would exceed the
            // leaf budget.
            if (n_leaves + splitter.new_leaves_added()) > max_leaves {
                break;
            }
            let mut node = growable.get_next_node();
            let n_idx = node.num;

            // Children of this node would sit one level deeper.
            let depth = node.depth + 1;

            if depth > max_depth {
                // Node stays a leaf; it was already pushed to `self.nodes`.
                continue;
            }

            // Tentatively stop counting this node as a leaf; restored below
            // if no split is found.
            n_leaves -= 1;

            let new_nodes = splitter.split_node(
                &n_nodes, &mut node, &mut index, col_index, data, cuts, grad, hess, parallel,
            );

            let n_new_nodes = new_nodes.len();
            if n_new_nodes == 0 {
                // No valid split: node remains a leaf.
                n_leaves += 1;
            } else {
                // Record child links/split info on the stored node, then
                // append children; missing-branch leaves are terminal and
                // never re-enter the grow queue.
                self.nodes[n_idx].make_parent_node(node);
                n_leaves += n_new_nodes;
                n_nodes += n_new_nodes;
                for n in new_nodes {
                    self.nodes.push(n.as_node(splitter.get_learning_rate()));
                    if !n.is_missing_leaf {
                        growable.add_node(n)
                    }
                }
            }
        }

        splitter.clean_up_splits(self);
    }
156
    /// Per-feature contributions for one row, measured as changes in
    /// probability (via `odds`) rather than in log-odds.
    ///
    /// At each split the difference between the child's and the parent's
    /// probability — each node weight evaluated on top of `current_logodds` —
    /// is credited to the split feature. The last slot of `contribs` (the
    /// bias term) receives the root's probability shift. Returns
    /// `current_logodds` plus the weight of the leaf the row lands in.
    pub fn predict_contributions_row_probability_change(
        &self,
        row: &[f64],
        contribs: &mut [f64],
        missing: &f64,
        current_logodds: f64,
    ) -> f64 {
        // Bias: probability change produced by the root weight alone.
        contribs[contribs.len() - 1] +=
            odds(current_logodds + self.nodes[0].weight_value as f64) - odds(current_logodds);
        let mut node_idx = 0;
        let mut lo = current_logodds;
        loop {
            let node = &self.nodes[node_idx];
            // Each node's odds are taken relative to the same base log-odds,
            // not the running `lo`, so per-step deltas stay comparable.
            let node_odds = odds(node.weight_value as f64 + current_logodds);
            if node.is_leaf {
                lo += node.weight_value as f64;
                break;
            }
            let child_idx = node.get_child_idx(&row[node.split_feature], missing);
            let child_odds = odds(self.nodes[child_idx].weight_value as f64 + current_logodds);
            let delta = child_odds - node_odds;
            contribs[node.split_feature] += delta;
            node_idx = child_idx;
        }
        lo
    }
184
185 pub fn predict_contributions_row_midpoint_difference(
187 &self,
188 row: &[f64],
189 contribs: &mut [f64],
190 missing: &f64,
191 ) {
192 let mut node_idx = 0;
195 loop {
196 let node = &self.nodes[node_idx];
197 if node.is_leaf {
198 break;
199 }
200 let child_idx = node.get_child_idx(&row[node.split_feature], missing);
209 let child = &self.nodes[child_idx];
210 if node.has_missing_branch() && child_idx == node.missing_node {
213 node_idx = child_idx;
214 continue;
215 }
216 let other_child = if child_idx == node.left_child {
217 &self.nodes[node.right_child]
218 } else {
219 &self.nodes[node.left_child]
220 };
221 let mid = (child.weight_value * child.hessian_sum
222 + other_child.weight_value * other_child.hessian_sum)
223 / (child.hessian_sum + other_child.hessian_sum);
224 let delta = child.weight_value - mid;
225 contribs[node.split_feature] += delta as f64;
226 node_idx = child_idx;
227 }
228 }
229
230 pub fn predict_contributions_row_branch_difference(
232 &self,
233 row: &[f64],
234 contribs: &mut [f64],
235 missing: &f64,
236 ) {
237 let mut node_idx = 0;
240 loop {
241 let node = &self.nodes[node_idx];
242 if node.is_leaf {
243 break;
244 }
245 let child_idx = node.get_child_idx(&row[node.split_feature], missing);
254 if node.has_missing_branch() && child_idx == node.missing_node {
257 node_idx = child_idx;
258 continue;
259 }
260 let other_child = if child_idx == node.left_child {
261 &self.nodes[node.right_child]
262 } else {
263 &self.nodes[node.left_child]
264 };
265 let delta = self.nodes[child_idx].weight_value - other_child.weight_value;
266 contribs[node.split_feature] += delta as f64;
267 node_idx = child_idx;
268 }
269 }
270
271 pub fn predict_contributions_row_mode_difference(
274 &self,
275 row: &[f64],
276 contribs: &mut [f64],
277 missing: &f64,
278 ) {
279 let mut node_idx = 0;
281 loop {
282 let node = &self.nodes[node_idx];
283 if node.is_leaf {
284 break;
285 }
286
287 let child_idx = node.get_child_idx(&row[node.split_feature], missing);
288 if node.has_missing_branch() && child_idx == node.missing_node {
291 node_idx = child_idx;
292 continue;
293 }
294 let left_node = &self.nodes[node.left_child];
295 let right_node = &self.nodes[node.right_child];
296 let child_weight = self.nodes[child_idx].weight_value;
297
298 let delta = if left_node.hessian_sum == right_node.hessian_sum {
299 0.
300 } else if left_node.hessian_sum > right_node.hessian_sum {
301 child_weight - left_node.weight_value
302 } else {
303 child_weight - right_node.weight_value
304 };
305 contribs[node.split_feature] += delta as f64;
306 node_idx = child_idx;
307 }
308 }
309
310 pub fn predict_contributions_row_weight(
311 &self,
312 row: &[f64],
313 contribs: &mut [f64],
314 missing: &f64,
315 ) {
316 contribs[contribs.len() - 1] += self.nodes[0].weight_value as f64;
318 let mut node_idx = 0;
319 loop {
320 let node = &self.nodes[node_idx];
321 if node.is_leaf {
322 break;
323 }
324 let child_idx = node.get_child_idx(&row[node.split_feature], missing);
326 let node_weight = self.nodes[node_idx].weight_value as f64;
327 let child_weight = self.nodes[child_idx].weight_value as f64;
328 let delta = child_weight - node_weight;
329 contribs[node.split_feature] += delta;
330 node_idx = child_idx
331 }
332 }
333
334 pub fn predict_contributions_weight(
335 &self,
336 data: &Matrix<f64>,
337 contribs: &mut [f64],
338 missing: &f64,
339 ) {
340 data.index
342 .par_iter()
343 .zip(contribs.par_chunks_mut(data.cols + 1))
344 .for_each(|(row, contribs)| {
345 self.predict_contributions_row_weight(&data.get_row(*row), contribs, missing)
346 })
347 }
348
349 pub fn predict_contributions_row_average(
351 &self,
352 row: &[f64],
353 contribs: &mut [f64],
354 weights: &[f64],
355 missing: &f64,
356 ) {
357 contribs[contribs.len() - 1] += weights[0];
359 let mut node_idx = 0;
360 loop {
361 let node = &self.nodes[node_idx];
362 if node.is_leaf {
363 break;
364 }
365 let child_idx = node.get_child_idx(&row[node.split_feature], missing);
367 let node_weight = weights[node_idx];
368 let child_weight = weights[child_idx];
369 let delta = child_weight - node_weight;
370 contribs[node.split_feature] += delta;
371 node_idx = child_idx
372 }
373 }
374
375 pub fn predict_contributions_average(
376 &self,
377 data: &Matrix<f64>,
378 contribs: &mut [f64],
379 weights: &[f64],
380 missing: &f64,
381 ) {
382 data.index
384 .par_iter()
385 .zip(contribs.par_chunks_mut(data.cols + 1))
386 .for_each(|(row, contribs)| {
387 self.predict_contributions_row_average(
388 &data.get_row(*row),
389 contribs,
390 weights,
391 missing,
392 )
393 })
394 }
395
396 fn predict_leaf(&self, data: &Matrix<f64>, row: usize, missing: &f64) -> &Node {
397 let mut node_idx = 0;
398 loop {
399 let node = &self.nodes[node_idx];
400 if node.is_leaf {
401 return node;
402 } else {
403 node_idx = node.get_child_idx(data.get(row, node.split_feature), missing);
404 }
405 }
406 }
407
408 pub fn predict_row_from_row_slice(&self, row: &[f64], missing: &f64) -> f64 {
409 let mut node_idx = 0;
410 loop {
411 let node = &self.nodes[node_idx];
412 if node.is_leaf {
413 return node.weight_value as f64;
414 } else {
415 node_idx = node.get_child_idx(&row[node.split_feature], missing);
416 }
417 }
418 }
419
420 fn predict_single_threaded(&self, data: &Matrix<f64>, missing: &f64) -> Vec<f64> {
421 data.index
422 .iter()
423 .map(|i| self.predict_leaf(data, *i, missing).weight_value as f64)
424 .collect()
425 }
426
427 fn predict_parallel(&self, data: &Matrix<f64>, missing: &f64) -> Vec<f64> {
428 data.index
429 .par_iter()
430 .map(|i| self.predict_leaf(data, *i, missing).weight_value as f64)
431 .collect()
432 }
433
434 pub fn predict(&self, data: &Matrix<f64>, parallel: bool, missing: &f64) -> Vec<f64> {
435 if parallel {
436 self.predict_parallel(data, missing)
437 } else {
438 self.predict_single_threaded(data, missing)
439 }
440 }
441
442 pub fn predict_leaf_indices(&self, data: &Matrix<f64>, missing: &f64) -> Vec<usize> {
443 data.index
444 .par_iter()
445 .map(|i| self.predict_leaf(data, *i, missing).num)
446 .collect()
447 }
448
    /// Partial dependence of this tree's output on `feature` held at `value`,
    /// starting the recursive traversal at the root with full proportion 1.0.
    pub fn value_partial_dependence(&self, feature: usize, value: f64, missing: &f64) -> f64 {
        tree_partial_dependence(self, 0, feature, value, 1.0, missing)
    }
    /// Recursively compute, for node `i`, the hessian-weighted average of the
    /// leaf weights beneath it (a leaf's value is its own weight). The value
    /// for every visited node is stored into `weights[i]` and also returned.
    fn distribute_node_leaf_weights(&self, i: usize, weights: &mut [f64]) -> f64 {
        let node = &self.nodes[i];
        let mut w = node.weight_value as f64;
        if !node.is_leaf {
            let left_node = &self.nodes[node.left_child];
            let right_node = &self.nodes[node.right_child];
            // Weight each child's average by its hessian mass...
            w = left_node.hessian_sum as f64
                * self.distribute_node_leaf_weights(node.left_child, weights);
            w += right_node.hessian_sum as f64
                * self.distribute_node_leaf_weights(node.right_child, weights);
            if node.has_missing_branch() {
                let missing_node = &self.nodes[node.missing_node];
                w += missing_node.hessian_sum as f64
                    * self.distribute_node_leaf_weights(node.missing_node, weights);
            }
            // ...then normalize by the parent's total hessian.
            w /= node.hessian_sum as f64;
        }
        weights[i] = w;
        w
    }
473 pub fn distribute_leaf_weights(&self) -> Vec<f64> {
474 let mut weights = vec![0.; self.nodes.len()];
475 self.distribute_node_leaf_weights(0, &mut weights);
476 weights
477 }
478
    /// Hessian-weighted average of the leaf weights beneath node `i`.
    /// NOTE(review): this duplicates the arithmetic of
    /// `distribute_node_leaf_weights` but without storing per-node results;
    /// keep the two in sync if either changes.
    pub fn get_average_leaf_weights(&self, i: usize) -> f64 {
        let node = &self.nodes[i];
        let mut w = node.weight_value as f64;
        if node.is_leaf {
            w
        } else {
            let left_node = &self.nodes[node.left_child];
            let right_node = &self.nodes[node.right_child];
            // Children's averages weighted by their hessian mass, normalized
            // by the parent's total hessian.
            w = left_node.hessian_sum as f64 * self.get_average_leaf_weights(node.left_child);
            w += right_node.hessian_sum as f64 * self.get_average_leaf_weights(node.right_child);
            if node.has_missing_branch() {
                let missing_node = &self.nodes[node.missing_node];
                w += missing_node.hessian_sum as f64
                    * self.get_average_leaf_weights(node.missing_node);
            }
            w /= node.hessian_sum as f64;
            w
        }
    }
499
    /// Recursively accumulate a per-feature statistic over every split node
    /// in the subtree rooted at `node`. For each split feature, `stats`
    /// holds `(sum of calc_stat over its splits, number of splits)`;
    /// leaves contribute nothing.
    fn calc_feature_node_stats<F>(
        &self,
        calc_stat: &F,
        node: &Node,
        stats: &mut HashMap<usize, (f32, usize)>,
    ) where
        F: Fn(&Node) -> f32,
    {
        if node.is_leaf {
            return;
        }
        stats
            .entry(node.split_feature)
            .and_modify(|(v, c)| {
                *v += calc_stat(node);
                *c += 1;
            })
            .or_insert((calc_stat(node), 1));
        self.calc_feature_node_stats(calc_stat, &self.nodes[node.left_child], stats);
        self.calc_feature_node_stats(calc_stat, &self.nodes[node.right_child], stats);
        if node.has_missing_branch() {
            self.calc_feature_node_stats(calc_stat, &self.nodes[node.missing_node], stats);
        }
    }
524
525 fn get_node_stats<F>(&self, calc_stat: &F, stats: &mut HashMap<usize, (f32, usize)>)
526 where
527 F: Fn(&Node) -> f32,
528 {
529 self.calc_feature_node_stats(calc_stat, &self.nodes[0], stats);
530 }
531
    /// Feature importance by split count: each split adds 1 for its feature.
    pub fn calculate_importance_weight(&self, stats: &mut HashMap<usize, (f32, usize)>) {
        self.get_node_stats(&|_: &Node| 1., stats);
    }

    /// Feature importance by accumulated split gain.
    pub fn calculate_importance_gain(&self, stats: &mut HashMap<usize, (f32, usize)>) {
        self.get_node_stats(&|n: &Node| n.split_gain, stats);
    }

    /// Feature importance by accumulated hessian (cover) at split nodes.
    pub fn calculate_importance_cover(&self, stats: &mut HashMap<usize, (f32, usize)>) {
        self.get_node_stats(&|n: &Node| n.hessian_sum, stats);
    }
543}
544
545impl Display for Tree {
546 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
548 let mut print_buffer: Vec<usize> = vec![0];
549 let mut r = String::new();
550 while let Some(idx) = print_buffer.pop() {
551 let node = &self.nodes[idx];
552 if node.is_leaf {
553 r += format!("{}{}\n", " ".repeat(node.depth).as_str(), node).as_str();
554 } else {
555 r += format!("{}{}\n", " ".repeat(node.depth).as_str(), node).as_str();
556 print_buffer.push(node.right_child);
557 print_buffer.push(node.left_child);
558 if node.has_missing_branch() {
559 print_buffer.push(node.missing_node);
560 }
561 }
562 }
563 write!(f, "{}", r)
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binning::bin_matrix;
    use crate::constraints::{Constraint, ConstraintMap};
    use crate::objective::{LogLoss, ObjectiveFunction};
    use crate::sampler::{RandomSampler, Sampler};
    use crate::splitter::MissingImputerSplitter;
    use crate::utils::precision_round;
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use std::fs;
    // Fitting on a 50% row subsample should still grow a tree; the sampler
    // must exclude some rows, and `SampleMethod::Random` makes `fit` sum
    // grad/hess over the sampled index only.
    #[test]
    fn test_tree_fit_with_subsample() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
        let data = Matrix::new(&data_vec, 891, 5);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 3.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };
        let mut tree = Tree::new();

        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let mut rng = StdRng::seed_from_u64(0);
        let (index, excluded) =
            RandomSampler::new(0.5).sample(&mut rng, &data.index, &mut g, &mut h);
        assert!(excluded.len() > 0);
        let col_index: Vec<usize> = (0..data.cols).collect();
        tree.fit(
            &bdata,
            index,
            &col_index,
            &b.cuts,
            &g,
            &h,
            &splitter,
            usize::MAX,
            5,
            true,
            &SampleMethod::Random,
            &GrowPolicy::DepthWise,
        );
    }

    // Full-data fit: checks the expected node count, then verifies that the
    // average-method and weight-method contributions each sum (per row) to
    // the tree's prediction.
    #[test]
    fn test_tree_fit() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let data = Matrix::new(&data_vec, 891, 5);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 3.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };
        let mut tree = Tree::new();

        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let col_index: Vec<usize> = (0..data.cols).collect();
        tree.fit(
            &bdata,
            data.index.to_owned(),
            &col_index,
            &b.cuts,
            &g,
            &h,
            &splitter,
            usize::MAX,
            5,
            true,
            &SampleMethod::None,
            &GrowPolicy::DepthWise,
        );

        assert_eq!(25, tree.nodes.len());
        let weights = tree.distribute_leaf_weights();
        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        tree.predict_contributions_average(&data, &mut contribs, &weights, &f64::NAN);
        let full_preds = tree.predict(&data, true, &f64::NAN);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);

        // Per-row contributions (including bias) must sum to the prediction.
        let contribs_preds: Vec<f64> = contribs
            .chunks(data.cols + 1)
            .map(|i| i.iter().sum())
            .collect();
        println!("{:?}", &contribs[0..10]);
        println!("{:?}", &contribs_preds[0..10]);

        assert_eq!(contribs_preds.len(), full_preds.len());
        for (i, j) in full_preds.iter().zip(contribs_preds) {
            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
        }

        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        tree.predict_contributions_weight(&data, &mut contribs, &f64::NAN);
        let full_preds = tree.predict(&data, true, &f64::NAN);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);

        let contribs_preds: Vec<f64> = contribs
            .chunks(data.cols + 1)
            .map(|i| i.iter().sum())
            .collect();
        println!("{:?}", &contribs[0..10]);
        println!("{:?}", &contribs_preds[0..10]);

        assert_eq!(contribs_preds.len(), full_preds.len());
        for (i, j) in full_preds.iter().zip(contribs_preds) {
            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
        }
    }

    // Restricting `col_index` to columns 1 and 3 must confine every split
    // in the fitted tree to those features.
    #[test]
    fn test_tree_colsample() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let data = Matrix::new(&data_vec, 891, 5);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 3.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: true,
            constraints_map: ConstraintMap::new(),
        };
        let mut tree = Tree::new();

        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let col_index: Vec<usize> = vec![1, 3];
        tree.fit(
            &bdata,
            data.index.to_owned(),
            &col_index,
            &b.cuts,
            &g,
            &h,
            &splitter,
            usize::MAX,
            5,
            false,
            &SampleMethod::None,
            &GrowPolicy::DepthWise,
        );
        for n in tree.nodes {
            if !n.is_leaf {
                assert!((n.split_feature == 1) || (n.split_feature == 3))
            }
        }
    }

    // A Negative monotone constraint on the single feature must make the
    // predictions non-increasing over sorted, deduplicated inputs; the
    // contribution methods must still sum to the predictions.
    #[test]
    fn test_tree_fit_monotone() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);
        println!("GRADIENT -- {:?}", h);

        let data_ = Matrix::new(&data_vec, 891, 5);
        let data = Matrix::new(data_.get_col(1), 891, 1);
        let map = ConstraintMap::from([(0, Constraint::Negative)]);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 0.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: true,
            constraints_map: map,
        };
        let mut tree = Tree::new();

        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let col_index: Vec<usize> = (0..data.cols).collect();
        tree.fit(
            &bdata,
            data.index.to_owned(),
            &col_index,
            &b.cuts,
            &g,
            &h,
            &splitter,
            usize::MAX,
            5,
            true,
            &SampleMethod::None,
            &GrowPolicy::DepthWise,
        );

        let mut pred_data_vec = data.get_col(0).to_owned();
        pred_data_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
        pred_data_vec.dedup();
        let pred_data = Matrix::new(&pred_data_vec, pred_data_vec.len(), 1);

        let preds = tree.predict(&pred_data, false, &f64::NAN);
        // Negative constraint: predictions must not increase with the feature.
        let increasing = preds.windows(2).all(|a| a[0] >= a[1]);
        assert!(increasing);

        let weights = tree.distribute_leaf_weights();

        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        tree.predict_contributions_average(&data, &mut contribs, &weights, &f64::NAN);
        let full_preds = tree.predict(&data, true, &f64::NAN);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
        let contribs_preds: Vec<f64> = contribs
            .chunks(data.cols + 1)
            .map(|i| i.iter().sum())
            .collect();
        assert_eq!(contribs_preds.len(), full_preds.len());
        for (i, j) in full_preds.iter().zip(contribs_preds) {
            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
        }

        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        tree.predict_contributions_weight(&data, &mut contribs, &f64::NAN);
        let full_preds = tree.predict(&data, true, &f64::NAN);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
        let contribs_preds: Vec<f64> = contribs
            .chunks(data.cols + 1)
            .map(|i| i.iter().sum())
            .collect();
        assert_eq!(contribs_preds.len(), full_preds.len());
        for (i, j) in full_preds.iter().zip(contribs_preds) {
            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
        }
    }

    // Loss-guide growth (BinaryHeap grower) with unlimited depth/leaves:
    // contributions from both methods must still sum to the predictions.
    #[test]
    fn test_tree_fit_lossguide() {
        let file = fs::read_to_string("resources/contiguous_no_missing.csv")
            .expect("Something went wrong reading the file");
        let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
        let yhat = vec![0.5; y.len()];
        let w = vec![1.; y.len()];
        let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, &w);

        let data = Matrix::new(&data_vec, 891, 5);
        let splitter = MissingImputerSplitter {
            l1: 0.0,
            l2: 1.0,
            max_delta_step: 0.,
            gamma: 3.0,
            min_leaf_weight: 1.0,
            learning_rate: 0.3,
            allow_missing_splits: false,
            constraints_map: ConstraintMap::new(),
        };
        let mut tree = Tree::new();

        let b = bin_matrix(&data, &w, 300, f64::NAN).unwrap();
        let bdata = Matrix::new(&b.binned_data, data.rows, data.cols);
        let col_index: Vec<usize> = (0..data.cols).collect();
        tree.fit(
            &bdata,
            data.index.to_owned(),
            &col_index,
            &b.cuts,
            &g,
            &h,
            &splitter,
            usize::MAX,
            usize::MAX,
            true,
            &SampleMethod::None,
            &GrowPolicy::LossGuide,
        );

        println!("{}", tree);
        let weights = tree.distribute_leaf_weights();
        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        tree.predict_contributions_average(&data, &mut contribs, &weights, &f64::NAN);
        let full_preds = tree.predict(&data, true, &f64::NAN);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);

        let contribs_preds: Vec<f64> = contribs
            .chunks(data.cols + 1)
            .map(|i| i.iter().sum())
            .collect();
        println!("{:?}", &contribs[0..10]);
        println!("{:?}", &contribs_preds[0..10]);

        assert_eq!(contribs_preds.len(), full_preds.len());
        for (i, j) in full_preds.iter().zip(contribs_preds) {
            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
        }

        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        tree.predict_contributions_weight(&data, &mut contribs, &f64::NAN);
        let full_preds = tree.predict(&data, true, &f64::NAN);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);

        let contribs_preds: Vec<f64> = contribs
            .chunks(data.cols + 1)
            .map(|i| i.iter().sum())
            .collect();
        println!("{:?}", &contribs[0..10]);
        println!("{:?}", &contribs_preds[0..10]);

        assert_eq!(contribs_preds.len(), full_preds.len());
        for (i, j) in full_preds.iter().zip(contribs_preds) {
            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
        }
    }
}