forust_ml/
gradientbooster.rs

1use crate::binning::bin_matrix;
2use crate::constraints::ConstraintMap;
3use crate::data::{Matrix, RowMajorMatrix};
4use crate::errors::ForustError;
5use crate::metric::{is_comparison_better, metric_callables, Metric, MetricFn};
6use crate::objective::{
7    calc_init_callables, gradient_hessian_callables, LogLoss, ObjectiveFunction, ObjectiveType,
8    SquaredLoss,
9};
10use crate::sampler::{GossSampler, RandomSampler, SampleMethod, Sampler};
11use crate::shapley::predict_contributions_row_shapley;
12use crate::splitter::{MissingBranchSplitter, MissingImputerSplitter, Splitter};
13use crate::tree::Tree;
14use crate::utils::{
15    fmt_vec_output, odds, validate_not_nan_vec, validate_positive_float_field,
16    validate_positive_not_nan_vec,
17};
18use log::info;
19use rand::rngs::StdRng;
20use rand::seq::IteratorRandom;
21use rand::SeedableRng;
22use rayon::prelude::*;
23use serde::{Deserialize, Deserializer, Serialize};
24use std::collections::{HashMap, HashSet};
25use std::fs;
26
/// One evaluation dataset supplied by the caller: `(features, y, sample_weight)`.
///
/// `fit` validates the second element as targets and the third as sample weights.
pub type EvaluationData<'a> = (Matrix<'a, f64>, &'a [f64], &'a [f64]);
/// An evaluation dataset as tracked during training:
/// `(features, y, sample_weight, running predictions)`.
///
/// The prediction vector starts at `base_score` and is updated after every tree.
pub type TrainingEvaluationData<'a> = (&'a Matrix<'a, f64>, &'a [f64], &'a [f64], Vec<f64>);
/// Callback that accumulates one tree's statistics into a per-feature map.
// NOTE(review): the meaning of the `(f32, usize)` value is not visible in this part of the
// file — presumably (accumulated statistic, count); confirm at the call sites.
type ImportanceFn = fn(&Tree, &mut HashMap<usize, (f32, usize)>);
30
/// Growth policy used when training a tree, i.e. how the next node to grow is selected.
///
/// The configured policy is passed to `Tree::fit` for every tree in the ensemble.
#[derive(Serialize, Deserialize)]
pub enum GrowPolicy {
    /// Presumably grows the tree level by level — confirm against `Tree::fit`.
    DepthWise,
    /// Presumably grows the candidate node with the largest loss reduction first —
    /// confirm against `Tree::fit`.
    LossGuide,
}
36
/// Method used to calculate per-feature prediction contributions.
#[derive(Serialize, Deserialize)]
pub enum ContributionsMethod {
    /// This method will use the internal leaf weights to calculate the contributions. This is the same as what is described by Saabas [here](https://blog.datadive.net/interpreting-random-forests/).
    Weight,
    /// If this option is specified, the average internal node values are calculated. This is equivalent to the `approx_contribs` parameter in XGBoost.
    Average,
    /// This method will calculate contributions by subtracting the weight of the other non-missing branch from the weight of the node the record will travel down. This method does not have the property where the contributions summed are equal to the final prediction of the model.
    BranchDifference,
    /// This method will calculate contributions by subtracting the mid-point between the right and left node (weighted by the cover of each node) from the weight of the node the record will travel down. This method does not have the property where the contributions summed are equal to the final prediction of the model.
    MidpointDifference,
    /// This method will calculate contributions by subtracting the weight of the node with the largest cover (the mode node) from the weight of the node the record will travel down. This method does not have the property where the contributions summed are equal to the final prediction of the model.
    ModeDifference,
    /// This method is only valid when the objective type is set to "LogLoss". This method will calculate contributions as the change in a record's probability of being 1 moving from a parent node to a child node. The sum of the returned contributions matrix will be equal to the probability a record will be 1. For example, given a model, `model.predict_contributions(X, method="ProbabilityChange") == 1 / (1 + np.exp(-model.predict(X)))`
    ProbabilityChange,
    /// This method computes the Shapley values for each record and feature.
    Shapley,
}
54
/// Method used to calculate variable (feature) importance.
#[derive(Serialize, Deserialize)]
pub enum ImportanceMethod {
    /// The number of times a feature is used to split the data across all trees.
    Weight,
    /// The average split gain across all splits the feature is used in.
    Gain,
    /// The average coverage across all splits the feature is used in.
    Cover,
    /// The total gain across all splits the feature is used in.
    TotalGain,
    /// The total coverage across all splits the feature is used in.
    TotalCover,
}
69
/// How the weights of missing nodes should be treated at training time.
///
/// The value is copied into the `MissingBranchSplitter` built by `fit` when
/// `create_missing_branch` is enabled.
#[derive(Serialize, Deserialize, Clone, Copy)]
pub enum MissingNodeTreatment {
    /// Calculate missing node weight values without any constraints.
    None,
    /// Assign the weight of the missing node to that of the parent.
    AssignToParent,
    /// After training each tree, starting from the bottom of the tree, assign the missing node weight to the weighted average of the left and right child nodes. Next assign the parent to the weighted average of the children nodes. This is performed recursively up through the entire tree, as a post-processing step on each tree after it is built and prior to updating the predictions used to train the next tree.
    AverageLeafWeight,
    /// Set the missing node to be equal to the weighted average weight of the left and the right nodes.
    AverageNodeWeight,
}
81
/// Gradient Booster object.
///
/// Holds the user-facing hyperparameters together with the trained trees and the
/// training bookkeeping (evaluation history, best iteration). The whole struct is
/// serializable; fields carrying a `#[serde(default = ...)]` attribute fall back to
/// that default when they are absent from the serialized input.
#[derive(Deserialize, Serialize)]
pub struct GradientBooster {
    /// The name of objective function used to optimize.
    /// Valid options include "LogLoss" to use logistic loss as the objective function,
    /// or "SquaredLoss" to use Squared Error as the objective function.
    pub objective_type: ObjectiveType,
    /// Total number of trees to train in the ensemble.
    pub iterations: usize,
    /// Step size to use at each iteration. Each
    /// leaf weight is multiplied by this number. The smaller the value, the more
    /// conservative the weights will be.
    pub learning_rate: f32,
    /// Maximum depth of an individual tree. Valid values are 0 to infinity.
    pub max_depth: usize,
    /// Maximum number of leaves allowed on a tree. Valid values
    /// are 0 to infinity. This is the total number of final nodes.
    pub max_leaves: usize,
    /// L1 regularization term applied to the weights of the tree. Valid values
    /// are 0 to infinity. 0 means no regularization is applied.
    #[serde(default = "default_l1")]
    pub l1: f32,
    /// L2 regularization term applied to the weights of the tree. Valid values
    /// are 0 to infinity.
    pub l2: f32,
    /// The minimum amount of loss required to further split a node.
    /// Valid values are 0 to infinity.
    pub gamma: f32,
    /// Maximum delta step allowed at each leaf. This is the maximum magnitude a leaf can take. Setting to 0 results in no constraint.
    #[serde(default = "default_max_delta_step")]
    pub max_delta_step: f32,
    /// Minimum sum of the hessian values of the loss function
    /// required to be in a node.
    pub min_leaf_weight: f32,
    /// The initial prediction value of the model. Overwritten at fit time when
    /// `initialize_base_score` is true.
    pub base_score: f64,
    /// Number of bins to calculate to partition the data. Setting this to
    /// a smaller number will result in faster training time, while potentially sacrificing
    /// accuracy. If there are more bins than unique values in a column, all unique values
    /// will be used.
    pub nbins: u16,
    /// Should training be run in parallel? The flag is forwarded to `Tree::fit` and to
    /// the per-tree predictions made while training.
    pub parallel: bool,
    /// Should the algorithm allow splits that completely separate out missing
    /// and non-missing values, in the case where `create_missing_branch` is false. When `create_missing_branch`
    /// is true, setting this to true will result in the missing branch being further split.
    pub allow_missing_splits: bool,
    /// Constraints that are used to enforce a specific relationship
    /// between the training features and the target variable.
    pub monotone_constraints: Option<ConstraintMap>,
    /// Percent of records to randomly sample at each iteration when training a tree.
    pub subsample: f32,
    /// Used only in goss. The retain ratio of large gradient data.
    #[serde(default = "default_top_rate")]
    pub top_rate: f64,
    /// Used only in goss. The retain ratio of small gradient data.
    #[serde(default = "default_other_rate")]
    pub other_rate: f64,
    /// Specify the fraction of columns that should be sampled at each iteration, valid values are in the range (0.0,1.0].
    #[serde(default = "default_colsample_bytree")]
    pub colsample_bytree: f64,
    /// Integer value used to seed any randomness used in the algorithm.
    pub seed: u64,
    /// Value to consider missing. A serialized null is read back as NaN.
    #[serde(deserialize_with = "parse_missing")]
    pub missing: f64,
    /// Should missing be split out into its own separate branch?
    pub create_missing_branch: bool,
    /// Specify the method by which records should be sampled when training.
    #[serde(default = "default_sample_method")]
    pub sample_method: SampleMethod,
    /// Growth policy to use when training a tree, this is how the next node is selected.
    #[serde(default = "default_grow_policy")]
    pub grow_policy: GrowPolicy,
    /// Define the evaluation metric to record at each iteration. When unset, the
    /// default metric of the objective is used.
    #[serde(default = "default_evaluation_metric")]
    pub evaluation_metric: Option<Metric>,
    /// Number of rounds within which the evaluation metric must improve to keep
    /// training. Training stops once this many iterations have passed since the best
    /// iteration, measured on the last evaluation dataset.
    #[serde(default = "default_early_stopping_rounds")]
    pub early_stopping_rounds: Option<usize>,
    /// If this is specified, the base_score will be calculated using the sample_weight and y data in accordance with the requested objective_type.
    #[serde(default = "default_initialize_base_score")]
    pub initialize_base_score: bool,
    /// A set of features for which the missing node will always be terminated, even
    /// if `allow_missing_splits` is set to true. This value is only valid if
    /// `create_missing_branch` is also true.
    #[serde(default = "default_terminate_missing_features")]
    pub terminate_missing_features: HashSet<usize>,
    /// A matrix of the evaluation history on the evaluation datasets: one row per
    /// trained iteration, one column per evaluation dataset.
    #[serde(default = "default_evaluation_history")]
    pub evaluation_history: Option<RowMajorMatrix<f64>>,
    /// Zero-based index of the iteration with the best evaluation metric. Only set when
    /// early stopping rounds and evaluation data are used.
    #[serde(default = "default_best_iteration")]
    pub best_iteration: Option<usize>,
    /// Number of trees to use when predicting,
    /// defaults to best_iteration if this is defined.
    #[serde(default = "default_prediction_iteration")]
    pub prediction_iteration: Option<usize>,
    /// How the missing nodes weights should be treated at training time.
    #[serde(default = "default_missing_node_treatment")]
    pub missing_node_treatment: MissingNodeTreatment,
    /// Logging frequency. 0 disables training logs; a value N other than zero logs
    /// information about every N-th iteration.
    #[serde(default = "default_log_iterations")]
    pub log_iterations: usize,
    /// Should the children nodes contain the parent node in their bounds? Setting this to true will result in no children being created that result in the higher and lower child values both being greater than, or less than, the parent weight.
    #[serde(default = "default_force_children_to_bound_parent")]
    pub force_children_to_bound_parent: bool,
    // Members internal to the booster object, and not parameters set by the user.
    // Trees is public, just to interact with it directly in the python wrapper.
    /// The trained trees, in the order they were fit.
    pub trees: Vec<Tree>,
    /// Free-form string key/value pairs stored with the booster; starts out empty.
    metadata: HashMap<String, String>,
}
193
/// Serde fallback for `l1`: no L1 regularization.
fn default_l1() -> f32 {
    0.0_f32
}
/// Serde fallback for `max_delta_step`: 0, i.e. leaf magnitudes are not constrained.
fn default_max_delta_step() -> f32 {
    0.0_f32
}
200
/// Serde fallback for `initialize_base_score`: keep the stored `base_score` as is.
fn default_initialize_base_score() -> bool {
    false
}
204
/// Serde fallback for `grow_policy`.
fn default_grow_policy() -> GrowPolicy {
    GrowPolicy::DepthWise
}
208
/// Serde fallback for `top_rate` (only used by GOSS sampling).
fn default_top_rate() -> f64 {
    0.1_f64
}
/// Serde fallback for `other_rate` (only used by GOSS sampling).
fn default_other_rate() -> f64 {
    0.2_f64
}
/// Serde fallback for `sample_method`: no record sampling.
fn default_sample_method() -> SampleMethod {
    SampleMethod::None
}
/// Serde fallback for `evaluation_metric`: unset, so the objective's default metric is used.
fn default_evaluation_metric() -> Option<Metric> {
    None
}
/// Serde fallback for `early_stopping_rounds`: early stopping disabled.
fn default_early_stopping_rounds() -> Option<usize> {
    Option::None
}
/// Serde fallback for `evaluation_history`: no history recorded.
fn default_evaluation_history() -> Option<RowMajorMatrix<f64>> {
    None
}
/// Serde fallback for `best_iteration`: not determined.
fn default_best_iteration() -> Option<usize> {
    Option::None
}
/// Serde fallback for `prediction_iteration`: not set.
fn default_prediction_iteration() -> Option<usize> {
    Option::None
}
/// Serde fallback for `terminate_missing_features`: an empty feature set.
fn default_terminate_missing_features() -> HashSet<usize> {
    HashSet::default()
}
/// Serde fallback for `colsample_bytree`: use every column for every tree.
fn default_colsample_bytree() -> f64 {
    1.0_f64
}
/// Serde fallback for `missing_node_treatment`.
fn default_missing_node_treatment() -> MissingNodeTreatment {
    MissingNodeTreatment::AssignToParent
}
242
/// Serde fallback for `log_iterations`: 0, i.e. no training logs.
fn default_log_iterations() -> usize {
    0_usize
}
/// Serde fallback for `force_children_to_bound_parent`: children are not forced to
/// bound the parent weight.
fn default_force_children_to_bound_parent() -> bool {
    false
}
249
250fn parse_missing<'de, D>(d: D) -> Result<f64, D::Error>
251where
252    D: Deserializer<'de>,
253{
254    Deserialize::deserialize(d).map(|x: Option<_>| x.unwrap_or(f64::NAN))
255}
256
257impl Default for GradientBooster {
258    fn default() -> Self {
259        Self::new(
260            ObjectiveType::LogLoss,
261            10,
262            0.3,
263            5,
264            usize::MAX,
265            0.,
266            1.,
267            0.,
268            0.,
269            1.,
270            0.5,
271            256,
272            true,
273            true,
274            None,
275            1.,
276            0.1,
277            0.2,
278            1.0,
279            0,
280            f64::NAN,
281            false,
282            SampleMethod::None,
283            GrowPolicy::DepthWise,
284            None,
285            None,
286            true,
287            HashSet::new(),
288            MissingNodeTreatment::AssignToParent,
289            0,
290            false,
291        )
292        .unwrap()
293    }
294}
295
296impl GradientBooster {
297    /// Gradient Booster object
298    ///
299    /// * `objective_type` - The name of objective function used to optimize.
300    ///   Valid options include "LogLoss" to use logistic loss as the objective function,
301    ///   or "SquaredLoss" to use Squared Error as the objective function.
302    /// * `iterations` - Total number of trees to train in the ensemble.
303    /// * `learning_rate` - Step size to use at each iteration. Each
304    ///   leaf weight is multiplied by this number. The smaller the value, the more
305    ///   conservative the weights will be.
306    /// * `max_depth` - Maximum depth of an individual tree. Valid values
307    ///   are 0 to infinity.
308    /// * `max_leaves` - Maximum number of leaves allowed on a tree. Valid values
309    ///   are 0 to infinity. This is the total number of final nodes.
310    /// * `l2` - L2 regularization term applied to the weights of the tree. Valid values
311    ///   are 0 to infinity.
312    /// * `gamma` - The minimum amount of loss required to further split a node.
313    ///   Valid values are 0 to infinity.
314    /// * `min_leaf_weight` - Minimum sum of the hessian values of the loss function
315    ///   required to be in a node.
316    /// * `base_score` - The initial prediction value of the model. If set to None the parameter `initialize_base_score` will automatically be set to `true`, in which case the base score will be chosen based on the objective function at fit time.
317    /// * `nbins` - Number of bins to calculate to partition the data. Setting this to
318    ///   a smaller number, will result in faster training time, while potentially sacrificing
319    ///   accuracy. If there are more bins, than unique values in a column, all unique values
320    ///   will be used.
321    /// * `parallel` - Should the algorithm be run in parallel?
322    /// * `allow_missing_splits` - Should the algorithm allow splits that completed seperate out missing
323    /// and non-missing values, in the case where `create_missing_branch` is false. When `create_missing_branch`
324    /// is true, setting this to true will result in the missin branch being further split.
325    /// * `monotone_constraints` - Constraints that are used to enforce a specific relationship
326    ///   between the training features and the target variable.
327    /// * `subsample` - Percent of records to randomly sample at each iteration when training a tree.
328    /// * `top_rate` - Used only in goss. The retain ratio of large gradient data.
329    /// * `other_rate` - Used only in goss. the retain ratio of small gradient data.
330    /// * `colsample_bytree` - Specify the fraction of columns that should be sampled at each iteration, valid values are in the range (0.0,1.0].
331    /// * `seed` - Integer value used to seed any randomness used in the algorithm.
332    /// * `missing` - Value to consider missing.
333    /// * `create_missing_branch` - Should missing be split out it's own separate branch?
334    /// * `sample_method` - Specify the method that records should be sampled when training?
335    /// * `evaluation_metric` - Define the evaluation metric to record at each iterations.
336    /// * `early_stopping_rounds` - Number of rounds that must
337    /// * `initialize_base_score` - If this is specified, the base_score will be calculated using the sample_weight and y data in accordance with the requested objective_type.
338    /// * `missing_node_treatment` - specify how missing nodes should be handled during training.
339    /// * `log_iterations` - Setting to a value (N) other than zero will result in information being logged about ever N iterations.
340    #[allow(clippy::too_many_arguments)]
341    pub fn new(
342        objective_type: ObjectiveType,
343        iterations: usize,
344        learning_rate: f32,
345        max_depth: usize,
346        max_leaves: usize,
347        l1: f32,
348        l2: f32,
349        gamma: f32,
350        max_delta_step: f32,
351        min_leaf_weight: f32,
352        base_score: f64,
353        nbins: u16,
354        parallel: bool,
355        allow_missing_splits: bool,
356        monotone_constraints: Option<ConstraintMap>,
357        subsample: f32,
358        top_rate: f64,
359        other_rate: f64,
360        colsample_bytree: f64,
361        seed: u64,
362        missing: f64,
363        create_missing_branch: bool,
364        sample_method: SampleMethod,
365        grow_policy: GrowPolicy,
366        evaluation_metric: Option<Metric>,
367        early_stopping_rounds: Option<usize>,
368        initialize_base_score: bool,
369        terminate_missing_features: HashSet<usize>,
370        missing_node_treatment: MissingNodeTreatment,
371        log_iterations: usize,
372        force_children_to_bound_parent: bool,
373    ) -> Result<Self, ForustError> {
374        let booster = GradientBooster {
375            objective_type,
376            iterations,
377            learning_rate,
378            max_depth,
379            max_leaves,
380            l1,
381            l2,
382            gamma,
383            max_delta_step,
384            min_leaf_weight,
385            base_score,
386            nbins,
387            parallel,
388            allow_missing_splits,
389            monotone_constraints,
390            subsample,
391            top_rate,
392            other_rate,
393            colsample_bytree,
394            seed,
395            missing,
396            create_missing_branch,
397            sample_method,
398            grow_policy,
399            evaluation_metric,
400            early_stopping_rounds,
401            initialize_base_score,
402            terminate_missing_features,
403            evaluation_history: None,
404            best_iteration: None,
405            prediction_iteration: None,
406            missing_node_treatment,
407            log_iterations,
408            force_children_to_bound_parent,
409            trees: Vec::new(),
410            metadata: HashMap::new(),
411        };
412        booster.validate_parameters()?;
413        Ok(booster)
414    }
415
    /// Validate the floating point hyperparameters of the booster.
    ///
    /// Each field is checked in turn with `validate_positive_float_field!`; the first
    /// failing field determines the returned error, so the order below matters.
    // NOTE(review): the macro body is not visible here — presumably it returns early
    // with a `ForustError` for negative/NaN values and accepts zero (the defaults pass
    // 0.0 for `l1`, `gamma` and `max_delta_step`); confirm against `crate::utils`.
    fn validate_parameters(&self) -> Result<(), ForustError> {
        validate_positive_float_field!(self.learning_rate);
        validate_positive_float_field!(self.l1);
        validate_positive_float_field!(self.l2);
        validate_positive_float_field!(self.gamma);
        validate_positive_float_field!(self.max_delta_step);
        validate_positive_float_field!(self.min_leaf_weight);
        validate_positive_float_field!(self.subsample);
        validate_positive_float_field!(self.top_rate);
        validate_positive_float_field!(self.other_rate);
        validate_positive_float_field!(self.colsample_bytree);
        Ok(())
    }
429
430    /// Fit the gradient booster on a provided dataset.
431    ///
432    /// * `data` -  Either a pandas DataFrame, or a 2 dimensional numpy array.
433    /// * `y` - Either a pandas Series, or a 1 dimensional numpy array.
434    /// * `sample_weight` - Instance weights to use when
435    /// training the model. If None is passed, a weight of 1 will be used for every record.
436    pub fn fit(
437        &mut self,
438        data: &Matrix<f64>,
439        y: &[f64],
440        sample_weight: &[f64],
441        evaluation_data: Option<Vec<EvaluationData>>,
442    ) -> Result<(), ForustError> {
443        // Validate inputs
444        validate_not_nan_vec(y, "y".to_string())?;
445        validate_positive_not_nan_vec(sample_weight, "sample_weight".to_string())?;
446        if let Some(eval_data) = &evaluation_data {
447            for (i, (_, eval_y, eval_sample_weight)) in eval_data.iter().enumerate() {
448                validate_not_nan_vec(eval_y, format!("eval set {} y", i).to_string())?;
449                validate_positive_not_nan_vec(
450                    eval_sample_weight,
451                    format!("eval set {} sample_weight", i).to_string(),
452                )?;
453            }
454        }
455
456        let constraints_map = self
457            .monotone_constraints
458            .as_ref()
459            .unwrap_or(&ConstraintMap::new())
460            .to_owned();
461        if self.create_missing_branch {
462            let splitter = MissingBranchSplitter {
463                l1: self.l1,
464                l2: self.l2,
465                max_delta_step: self.max_delta_step,
466                gamma: self.gamma,
467                min_leaf_weight: self.min_leaf_weight,
468                learning_rate: self.learning_rate,
469                allow_missing_splits: self.allow_missing_splits,
470                constraints_map,
471                terminate_missing_features: self.terminate_missing_features.clone(),
472                missing_node_treatment: self.missing_node_treatment,
473                force_children_to_bound_parent: self.force_children_to_bound_parent,
474            };
475            self.fit_trees(y, sample_weight, data, &splitter, evaluation_data)?;
476        } else {
477            let splitter = MissingImputerSplitter {
478                l1: self.l1,
479                l2: self.l2,
480                max_delta_step: self.max_delta_step,
481                gamma: self.gamma,
482                min_leaf_weight: self.min_leaf_weight,
483                learning_rate: self.learning_rate,
484                allow_missing_splits: self.allow_missing_splits,
485                constraints_map,
486            };
487            self.fit_trees(y, sample_weight, data, &splitter, evaluation_data)?;
488        };
489
490        Ok(())
491    }
492
493    fn sample_index(
494        &self,
495        rng: &mut StdRng,
496        index: &[usize],
497        grad: &mut [f32],
498        hess: &mut [f32],
499    ) -> (Vec<usize>, Vec<usize>) {
500        match self.sample_method {
501            SampleMethod::None => (index.to_owned(), Vec::new()),
502            SampleMethod::Random => {
503                RandomSampler::new(self.subsample).sample(rng, index, grad, hess)
504            }
505            SampleMethod::Goss => {
506                GossSampler::new(self.top_rate, self.other_rate).sample(rng, index, grad, hess)
507            }
508        }
509    }
510
511    fn get_metric_fn(&self) -> (MetricFn, bool) {
512        let metric = match &self.evaluation_metric {
513            None => match self.objective_type {
514                ObjectiveType::LogLoss => LogLoss::default_metric(),
515                ObjectiveType::SquaredLoss => SquaredLoss::default_metric(),
516            },
517            Some(v) => *v,
518        };
519        metric_callables(&metric)
520    }
521
522    fn reset(&mut self) {
523        self.trees = Vec::new();
524        self.evaluation_history = None;
525        self.best_iteration = None;
526        self.prediction_iteration = None;
527    }
528
529    fn fit_trees<T: Splitter>(
530        &mut self,
531        y: &[f64],
532        sample_weight: &[f64],
533        data: &Matrix<f64>,
534        splitter: &T,
535        evaluation_data: Option<Vec<EvaluationData>>,
536    ) -> Result<(), ForustError> {
537        // Is this a booster that has already been fit? If it is, reset the trees.
538        // In the future we could continue training.
539        if !self.trees.is_empty() {
540            self.reset()
541        }
542
543        let mut rng = StdRng::seed_from_u64(self.seed);
544
545        if self.initialize_base_score {
546            self.base_score = calc_init_callables(&self.objective_type)(y, sample_weight);
547        }
548
549        let mut yhat = vec![self.base_score; y.len()];
550
551        let calc_grad_hess = gradient_hessian_callables(&self.objective_type);
552        let (mut grad, mut hess) = calc_grad_hess(y, &yhat, sample_weight);
553
554        // Generate binned data
555        // TODO
556        // In scikit-learn, they sample 200_000 records for generating the bins.
557        // we could consider that, especially if this proved to be a large bottleneck...
558        let binned_data = bin_matrix(data, sample_weight, self.nbins, self.missing)?;
559        let bdata = Matrix::new(&binned_data.binned_data, data.rows, data.cols);
560
561        // Create the predictions, saving them with the evaluation data.
562        let mut evaluation_sets: Option<Vec<TrainingEvaluationData>> =
563            evaluation_data.as_ref().map(|evals| {
564                evals
565                    .iter()
566                    .map(|(d, y, w)| (d, *y, *w, vec![self.base_score; y.len()]))
567                    .collect()
568            });
569
570        let mut best_metric: Option<f64> = None;
571
572        // This will always be false, unless early stopping rounds are used.
573        let mut stop_early = false;
574        let col_index: Vec<usize> = (0..data.cols).collect();
575        for i in 0..self.iterations {
576            let verbose = if self.log_iterations == 0 {
577                false
578            } else {
579                i % self.log_iterations == 0
580            };
581            // We will eventually use the excluded index.
582            let (chosen_index, _excluded_index) =
583                self.sample_index(&mut rng, &data.index, &mut grad, &mut hess);
584            let mut tree = Tree::new();
585
586            // If we are doing any column sampling...
587            let colsample_index: Vec<usize> = if self.colsample_bytree == 1.0 {
588                Vec::new()
589            } else {
590                let amount = ((col_index.len() as f64) * self.colsample_bytree).floor() as usize;
591                let mut v: Vec<usize> = col_index
592                    .iter()
593                    .choose_multiple(&mut rng, amount)
594                    .iter()
595                    .map(|i| **i)
596                    .collect();
597                v.sort();
598                v
599            };
600
601            let fit_col_index = if self.colsample_bytree == 1.0 {
602                &col_index
603            } else {
604                &colsample_index
605            };
606
607            tree.fit(
608                &bdata,
609                chosen_index,
610                fit_col_index,
611                &binned_data.cuts,
612                &grad,
613                &hess,
614                splitter,
615                self.max_leaves,
616                self.max_depth,
617                self.parallel,
618                &self.sample_method,
619                &self.grow_policy,
620            );
621
622            self.update_predictions_inplace(&mut yhat, &tree, data);
623
624            // Update Evaluation data, if it's needed.
625            if let Some(eval_sets) = &mut evaluation_sets {
626                if self.evaluation_history.is_none() {
627                    self.evaluation_history =
628                        Some(RowMajorMatrix::new(Vec::new(), 0, eval_sets.len()));
629                }
630                let mut metrics: Vec<f64> = Vec::new();
631                let n_eval_sets = eval_sets.len();
632                for (eval_i, (data, y, w, yhat)) in eval_sets.iter_mut().enumerate() {
633                    self.update_predictions_inplace(yhat, &tree, data);
634                    let (metric_fn, maximize) = self.get_metric_fn();
635                    let m = metric_fn(y, yhat, w);
636                    // If early stopping rounds are defined, and this is the last
637                    // eval dataset, check if we want to stop or keep training.
638                    // Updating to align with XGBoost, originally we were using the first
639                    // dataset, but switching to use the last.
640                    if (eval_i + 1) == n_eval_sets {
641                        if let Some(early_stopping_rounds) = self.early_stopping_rounds {
642                            // If best metric is undefined, this must be the first
643                            // iteration...
644                            best_metric = match best_metric {
645                                None => {
646                                    self.update_best_iteration(i);
647                                    Some(m)
648                                }
649                                // Otherwise the best could be farther back.
650                                Some(v) => {
651                                    // We have reached a new best value...
652                                    if is_comparison_better(v, m, maximize) {
653                                        self.update_best_iteration(i);
654                                        Some(m)
655                                    } else {
656                                        // Previous value was better.
657                                        if let Some(best_iteration) = self.best_iteration {
658                                            if i - best_iteration >= early_stopping_rounds {
659                                                // If any logging is requested, print this message.
660                                                if self.log_iterations > 0 {
661                                                    info!("Stopping early at iteration {} with metric value {}", i, m)
662                                                }
663                                                stop_early = true;
664                                            }
665                                        }
666                                        Some(v)
667                                    }
668                                }
669                            };
670                        }
671                    }
672                    metrics.push(m);
673                }
674                if verbose {
675                    info!(
676                        "Iteration {} evaluation data values: {}",
677                        i,
678                        fmt_vec_output(&metrics)
679                    );
680                }
681                if let Some(history) = &mut self.evaluation_history {
682                    history.append_row(metrics);
683                }
684            }
685            self.trees.push(tree);
686
687            // Did we trigger the early stopping rounds criteria?
688            if stop_early {
689                break;
690            }
691
692            (grad, hess) = calc_grad_hess(y, &yhat, sample_weight);
693            if verbose {
694                info!("Completed iteration {} of {}", i, self.iterations);
695            }
696        }
697        if self.log_iterations > 0 {
698            info!(
699                "Finished training booster with {} iterations.",
700                self.trees.len()
701            );
702        }
703        Ok(())
704    }
705
706    fn update_best_iteration(&mut self, i: usize) {
707        self.best_iteration = Some(i);
708        self.prediction_iteration = Some(i + 1);
709    }
710
711    fn update_predictions_inplace(&self, yhat: &mut [f64], tree: &Tree, data: &Matrix<f64>) {
712        let preds = tree.predict(data, self.parallel, &self.missing);
713        yhat.iter_mut().zip(preds).for_each(|(i, j)| *i += j);
714    }
715
716    /// Fit the gradient booster on a provided dataset without any weights.
717    ///
718    /// * `data` -  Either a pandas DataFrame, or a 2 dimensional numpy array.
719    /// * `y` - Either a pandas Series, or a 1 dimensional numpy array.
720    pub fn fit_unweighted(
721        &mut self,
722        data: &Matrix<f64>,
723        y: &[f64],
724        evaluation_data: Option<Vec<EvaluationData>>,
725    ) -> Result<(), ForustError> {
726        let sample_weight = vec![1.0; data.rows];
727        self.fit(data, y, &sample_weight, evaluation_data)
728    }
729
730    /// Generate predictions on data using the gradient booster.
731    ///
732    /// * `data` -  Either a pandas DataFrame, or a 2 dimensional numpy array.
733    pub fn predict(&self, data: &Matrix<f64>, parallel: bool) -> Vec<f64> {
734        let mut init_preds = vec![self.base_score; data.rows];
735        self.get_prediction_trees().iter().for_each(|tree| {
736            for (p_, val) in init_preds
737                .iter_mut()
738                .zip(tree.predict(data, parallel, &self.missing))
739            {
740                *p_ += val;
741            }
742        });
743        init_preds
744    }
745
746    /// Predict the leaf Indexes, this returns a vector of length N records * N Trees
747    pub fn predict_leaf_indices(&self, data: &Matrix<f64>) -> Vec<usize> {
748        self.get_prediction_trees()
749            .iter()
750            .flat_map(|tree| tree.predict_leaf_indices(data, &self.missing))
751            .collect()
752    }
753
754    /// Predict the contributions matrix for the provided dataset.
755    pub fn predict_contributions(
756        &self,
757        data: &Matrix<f64>,
758        method: ContributionsMethod,
759        parallel: bool,
760    ) -> Vec<f64> {
761        match method {
762            ContributionsMethod::Average => self.predict_contributions_average(data, parallel),
763            ContributionsMethod::ProbabilityChange => {
764                match self.objective_type {
765                    ObjectiveType::LogLoss => {},
766                    _ => panic!("ProbabilityChange contributions method is only valid when LogLoss objective is used.")
767                }
768                self.predict_contributions_probability_change(data, parallel)
769            }
770            _ => self.predict_contributions_tree_alone(data, parallel, method),
771        }
772    }
773
774    // All of the contribution calculation methods, except for average are calculated
775    // using just the model, so we don't need to have separate methods, we can instead
776    // just have this one method, that dispatches to each one respectively.
777    fn predict_contributions_tree_alone(
778        &self,
779        data: &Matrix<f64>,
780        parallel: bool,
781        method: ContributionsMethod,
782    ) -> Vec<f64> {
783        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
784
785        // Add the bias term to every bias value...
786        let bias_idx = data.cols + 1;
787        contribs
788            .iter_mut()
789            .skip(bias_idx - 1)
790            .step_by(bias_idx)
791            .for_each(|v| *v += self.base_score);
792
793        let row_pred_fn = match method {
794            ContributionsMethod::Weight => Tree::predict_contributions_row_weight,
795            ContributionsMethod::BranchDifference => {
796                Tree::predict_contributions_row_branch_difference
797            }
798            ContributionsMethod::MidpointDifference => {
799                Tree::predict_contributions_row_midpoint_difference
800            }
801            ContributionsMethod::ModeDifference => Tree::predict_contributions_row_mode_difference,
802            ContributionsMethod::Shapley => predict_contributions_row_shapley,
803            ContributionsMethod::Average | ContributionsMethod::ProbabilityChange => unreachable!(),
804        };
805        // Clean this up..
806        // materializing a row, and then passing that to all of the
807        // trees seems to be the fastest approach (5X faster), we should test
808        // something like this for normal predictions.
809        if parallel {
810            data.index
811                .par_iter()
812                .zip(contribs.par_chunks_mut(data.cols + 1))
813                .for_each(|(row, c)| {
814                    let r_ = data.get_row(*row);
815                    self.get_prediction_trees().iter().for_each(|t| {
816                        row_pred_fn(t, &r_, c, &self.missing);
817                    });
818                });
819        } else {
820            data.index
821                .iter()
822                .zip(contribs.chunks_mut(data.cols + 1))
823                .for_each(|(row, c)| {
824                    let r_ = data.get_row(*row);
825                    self.get_prediction_trees().iter().for_each(|t| {
826                        row_pred_fn(t, &r_, c, &self.missing);
827                    });
828                });
829        }
830
831        contribs
832    }
833
834    /// Get the a reference to the trees for predicting, ensureing that the right number of
835    /// trees are used.
836    fn get_prediction_trees(&self) -> &[Tree] {
837        let n_iterations = self.prediction_iteration.unwrap_or(self.trees.len());
838        &self.trees[..n_iterations]
839    }
840
841    /// Generate predictions on data using the gradient booster.
842    /// This is equivalent to the XGBoost predict contributions with approx_contribs
843    ///
844    /// * `data` -  Either a pandas DataFrame, or a 2 dimensional numpy array.
845    fn predict_contributions_average(&self, data: &Matrix<f64>, parallel: bool) -> Vec<f64> {
846        let weights: Vec<Vec<f64>> = if parallel {
847            self.get_prediction_trees()
848                .par_iter()
849                .map(|t| t.distribute_leaf_weights())
850                .collect()
851        } else {
852            self.get_prediction_trees()
853                .iter()
854                .map(|t| t.distribute_leaf_weights())
855                .collect()
856        };
857        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
858
859        // Add the bias term to every bias value...
860        let bias_idx = data.cols + 1;
861        contribs
862            .iter_mut()
863            .skip(bias_idx - 1)
864            .step_by(bias_idx)
865            .for_each(|v| *v += self.base_score);
866
867        // Clean this up..
868        // materializing a row, and then passing that to all of the
869        // trees seems to be the fastest approach (5X faster), we should test
870        // something like this for normal predictions.
871        if parallel {
872            data.index
873                .par_iter()
874                .zip(contribs.par_chunks_mut(data.cols + 1))
875                .for_each(|(row, c)| {
876                    let r_ = data.get_row(*row);
877                    self.get_prediction_trees()
878                        .iter()
879                        .zip(weights.iter())
880                        .for_each(|(t, w)| {
881                            t.predict_contributions_row_average(&r_, c, w, &self.missing);
882                        });
883                });
884        } else {
885            data.index
886                .iter()
887                .zip(contribs.chunks_mut(data.cols + 1))
888                .for_each(|(row, c)| {
889                    let r_ = data.get_row(*row);
890                    self.get_prediction_trees()
891                        .iter()
892                        .zip(weights.iter())
893                        .for_each(|(t, w)| {
894                            t.predict_contributions_row_average(&r_, c, w, &self.missing);
895                        });
896                });
897        }
898
899        contribs
900    }
901
902    fn predict_contributions_probability_change(
903        &self,
904        data: &Matrix<f64>,
905        parallel: bool,
906    ) -> Vec<f64> {
907        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
908        let bias_idx = data.cols + 1;
909        contribs
910            .iter_mut()
911            .skip(bias_idx - 1)
912            .step_by(bias_idx)
913            .for_each(|v| *v += odds(self.base_score));
914
915        if parallel {
916            data.index
917                .par_iter()
918                .zip(contribs.par_chunks_mut(data.cols + 1))
919                .for_each(|(row, c)| {
920                    let r_ = data.get_row(*row);
921                    self.get_prediction_trees()
922                        .iter()
923                        .fold(self.base_score, |acc, t| {
924                            t.predict_contributions_row_probability_change(
925                                &r_,
926                                c,
927                                &self.missing,
928                                acc,
929                            )
930                        });
931                });
932        } else {
933            data.index
934                .iter()
935                .zip(contribs.chunks_mut(data.cols + 1))
936                .for_each(|(row, c)| {
937                    let r_ = data.get_row(*row);
938                    self.get_prediction_trees()
939                        .iter()
940                        .fold(self.base_score, |acc, t| {
941                            t.predict_contributions_row_probability_change(
942                                &r_,
943                                c,
944                                &self.missing,
945                                acc,
946                            )
947                        });
948                });
949        }
950        contribs
951    }
952
953    /// Given a value, return the partial dependence value of that value for that
954    /// feature in the model.
955    ///
956    /// * `feature` - The index of the feature.
957    /// * `value` - The value for which to calculate the partial dependence.
958    pub fn value_partial_dependence(&self, feature: usize, value: f64) -> f64 {
959        let pd: f64 = if self.parallel {
960            self.get_prediction_trees()
961                .par_iter()
962                .map(|t| t.value_partial_dependence(feature, value, &self.missing))
963                .sum()
964        } else {
965            self.get_prediction_trees()
966                .iter()
967                .map(|t| t.value_partial_dependence(feature, value, &self.missing))
968                .sum()
969        };
970        pd + self.base_score
971    }
972
973    /// Calculate feature importance measure for the features
974    /// in the model.
975    /// - `method`: variable importance method to use.
976    /// - `n_features`: The number of features to calculate the importance for.
977    pub fn calculate_feature_importance(
978        &self,
979        method: ImportanceMethod,
980        normalize: bool,
981    ) -> HashMap<usize, f32> {
982        let (average, importance_fn): (bool, ImportanceFn) = match method {
983            ImportanceMethod::Weight => (false, Tree::calculate_importance_weight),
984            ImportanceMethod::Gain => (true, Tree::calculate_importance_gain),
985            ImportanceMethod::TotalGain => (false, Tree::calculate_importance_gain),
986            ImportanceMethod::Cover => (true, Tree::calculate_importance_cover),
987            ImportanceMethod::TotalCover => (false, Tree::calculate_importance_cover),
988        };
989        let mut stats = HashMap::new();
990        for tree in self.trees.iter() {
991            importance_fn(tree, &mut stats)
992        }
993
994        let importance = stats
995            .iter()
996            .map(|(k, (v, c))| {
997                if average {
998                    (*k, v / (*c as f32))
999                } else {
1000                    (*k, *v)
1001                }
1002            })
1003            .collect::<HashMap<usize, f32>>();
1004
1005        if normalize {
1006            // To make deterministic, sort values and then sum.
1007            // Otherwise we were getting them in different orders, and
1008            // floating point error was creeping in.
1009            let mut values: Vec<f32> = importance.values().copied().collect();
1010            // We are OK to unwrap because we know we will never have missing.
1011            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1012            let total: f32 = values.iter().sum();
1013            importance.iter().map(|(k, v)| (*k, v / total)).collect()
1014        } else {
1015            importance
1016        }
1017    }
1018
1019    /// Save a booster as a json object to a file.
1020    ///
1021    /// * `path` - Path to save booster.
1022    pub fn save_booster(&self, path: &str) -> Result<(), ForustError> {
1023        let model = self.json_dump()?;
1024        match fs::write(path, model) {
1025            Err(e) => Err(ForustError::UnableToWrite(e.to_string())),
1026            Ok(_) => Ok(()),
1027        }
1028    }
1029
1030    /// Dump a booster as a json object
1031    pub fn json_dump(&self) -> Result<String, ForustError> {
1032        match serde_json::to_string(self) {
1033            Ok(s) => Ok(s),
1034            Err(e) => Err(ForustError::UnableToWrite(e.to_string())),
1035        }
1036    }
1037
1038    /// Load a booster from Json string
1039    ///
1040    /// * `json_str` - String object, which can be serialized to json.
1041    pub fn from_json(json_str: &str) -> Result<Self, ForustError> {
1042        let model = serde_json::from_str::<GradientBooster>(json_str);
1043        match model {
1044            Ok(m) => Ok(m),
1045            Err(e) => Err(ForustError::UnableToRead(e.to_string())),
1046        }
1047    }
1048
1049    /// Load a booster from a path to a json booster object.
1050    ///
1051    /// * `path` - Path to load booster from.
1052    pub fn load_booster(path: &str) -> Result<Self, ForustError> {
1053        let json_str = match fs::read_to_string(path) {
1054            Ok(s) => Ok(s),
1055            Err(e) => Err(ForustError::UnableToRead(e.to_string())),
1056        }?;
1057        Self::from_json(&json_str)
1058    }
1059
1060    // Set methods for paramters
1061    /// Set the objective_type on the booster.
1062    /// * `objective_type` - The objective type of the booster.
1063    pub fn set_objective_type(mut self, objective_type: ObjectiveType) -> Self {
1064        self.objective_type = objective_type;
1065        self
1066    }
1067
1068    /// Set the iterations on the booster.
1069    /// * `iterations` - The number of iterations of the booster.
1070    pub fn set_iterations(mut self, iterations: usize) -> Self {
1071        self.iterations = iterations;
1072        self
1073    }
1074
1075    /// Set the learning_rate on the booster.
1076    /// * `learning_rate` - The learning rate of the booster.
1077    pub fn set_learning_rate(mut self, learning_rate: f32) -> Self {
1078        self.learning_rate = learning_rate;
1079        self
1080    }
1081
1082    /// Set the max_depth on the booster.
1083    /// * `max_depth` - The maximum tree depth of the booster.
1084    pub fn set_max_depth(mut self, max_depth: usize) -> Self {
1085        self.max_depth = max_depth;
1086        self
1087    }
1088
1089    /// Set the max_leaves on the booster.
1090    /// * `max_leaves` - The maximum number of leaves of the booster.
1091    pub fn set_max_leaves(mut self, max_leaves: usize) -> Self {
1092        self.max_leaves = max_leaves;
1093        self
1094    }
1095
1096    /// Set the number of nbins on the booster.
1097    /// * `max_leaves` - Number of bins to calculate to partition the data. Setting this to
1098    ///   a smaller number, will result in faster training time, while potentially sacrificing
1099    ///   accuracy. If there are more bins, than unique values in a column, all unique values
1100    ///   will be used.
1101    pub fn set_nbins(mut self, nbins: u16) -> Self {
1102        self.nbins = nbins;
1103        self
1104    }
1105
1106    /// Set the l1 on the booster.
1107    /// * `l1` - The l1 regulation term of the booster.
1108    pub fn set_l1(mut self, l1: f32) -> Self {
1109        self.l1 = l1;
1110        self
1111    }
1112
1113    /// Set the l2 on the booster.
1114    /// * `l2` - The l2 regulation term of the booster.
1115    pub fn set_l2(mut self, l2: f32) -> Self {
1116        self.l2 = l2;
1117        self
1118    }
1119
1120    /// Set the gamma on the booster.
1121    /// * `gamma` - The gamma value of the booster.
1122    pub fn set_gamma(mut self, gamma: f32) -> Self {
1123        self.gamma = gamma;
1124        self
1125    }
1126
1127    /// Set the max_delta_step on the booster.
1128    /// * `max_delta_step` - The max_delta_step value of the booster.
1129    pub fn set_max_delta_step(mut self, max_delta_step: f32) -> Self {
1130        self.max_delta_step = max_delta_step;
1131        self
1132    }
1133
1134    /// Set the min_leaf_weight on the booster.
1135    /// * `min_leaf_weight` - The minimum sum of the hession values allowed in the
1136    ///     node of a tree of the booster.
1137    pub fn set_min_leaf_weight(mut self, min_leaf_weight: f32) -> Self {
1138        self.min_leaf_weight = min_leaf_weight;
1139        self
1140    }
1141
1142    /// Set the base_score on the booster.
1143    /// * `base_score` - The base score of the booster.
1144    pub fn set_base_score(mut self, base_score: f64) -> Self {
1145        self.base_score = base_score;
1146        self
1147    }
1148
1149    /// Set the base_score on the booster.
1150    /// * `base_score` - The base score of the booster.
1151    pub fn set_initialize_base_score(mut self, initialize_base_score: bool) -> Self {
1152        self.initialize_base_score = initialize_base_score;
1153        self
1154    }
1155
1156    /// Set the parallel on the booster.
1157    /// * `parallel` - Set if the booster should be trained in parallels.
1158    pub fn set_parallel(mut self, parallel: bool) -> Self {
1159        self.parallel = parallel;
1160        self
1161    }
1162
1163    /// Set the allow_missing_splits on the booster.
1164    /// * `allow_missing_splits` - Set if missing splits are allowed for the booster.
1165    pub fn set_allow_missing_splits(mut self, allow_missing_splits: bool) -> Self {
1166        self.allow_missing_splits = allow_missing_splits;
1167        self
1168    }
1169
1170    /// Set the monotone_constraints on the booster.
1171    /// * `monotone_constraints` - The monotone constraints of the booster.
1172    pub fn set_monotone_constraints(mut self, monotone_constraints: Option<ConstraintMap>) -> Self {
1173        self.monotone_constraints = monotone_constraints;
1174        self
1175    }
1176
1177    /// Set the subsample on the booster.
1178    /// * `subsample` - Percent of the data to randomly sample when training each tree.
1179    pub fn set_subsample(mut self, subsample: f32) -> Self {
1180        self.subsample = subsample;
1181        self
1182    }
1183
1184    /// Set the colsample_bytree on the booster.
1185    /// * `colsample_bytree` - Percent of the columns to randomly sample when training each tree.
1186    pub fn set_colsample_bytree(mut self, colsample_bytree: f64) -> Self {
1187        self.colsample_bytree = colsample_bytree;
1188        self
1189    }
1190
1191    /// Set the seed on the booster.
1192    /// * `seed` - Integer value used to see any randomness used in the algorithm.
1193    pub fn set_seed(mut self, seed: u64) -> Self {
1194        self.seed = seed;
1195        self
1196    }
1197
1198    /// Set missing value of the booster
1199    /// * `missing` - Float value to consider as missing.
1200    pub fn set_missing(mut self, missing: f64) -> Self {
1201        self.missing = missing;
1202        self
1203    }
1204
1205    /// Set create missing value of the booster
1206    /// * `create_missing_branch` - Bool specifying if missing should get it's own
1207    /// branch.
1208    pub fn set_create_missing_branch(mut self, create_missing_branch: bool) -> Self {
1209        self.create_missing_branch = create_missing_branch;
1210        self
1211    }
1212
1213    /// Set sample method on the booster.
1214    /// * `sample_method` - Sample method.
1215    pub fn set_sample_method(mut self, sample_method: SampleMethod) -> Self {
1216        self.sample_method = sample_method;
1217        self
1218    }
1219
1220    /// Set sample method on the booster.
1221    /// * `evaluation_metric` - Sample method.
1222    pub fn set_evaluation_metric(mut self, evaluation_metric: Option<Metric>) -> Self {
1223        self.evaluation_metric = evaluation_metric;
1224        self
1225    }
1226
1227    /// Set early stopping rounds.
1228    /// * `early_stopping_rounds` - Early stoppings rounds.
1229    pub fn set_early_stopping_rounds(mut self, early_stopping_rounds: Option<usize>) -> Self {
1230        self.early_stopping_rounds = early_stopping_rounds;
1231        self
1232    }
1233
1234    /// Set prediction iterations.
1235    /// * `early_stopping_rounds` - Early stoppings rounds.
1236    pub fn set_prediction_iteration(mut self, prediction_iteration: Option<usize>) -> Self {
1237        self.prediction_iteration = prediction_iteration.map(|i| i + 1);
1238        self
1239    }
1240
1241    /// Set the features where whose missing nodes should
1242    /// always be terminated.
1243    /// * `terminate_missing_features` - Hashset of the feature indices for the
1244    /// features that should always terminate the missing node, if create_missing_branch
1245    /// is true.
1246    pub fn set_terminate_missing_features(
1247        mut self,
1248        terminate_missing_features: HashSet<usize>,
1249    ) -> Self {
1250        self.terminate_missing_features = terminate_missing_features;
1251        self
1252    }
1253
1254    /// Insert metadata
1255    /// * `key` - String value for the metadata key.
1256    /// * `value` - value to assign to the metadata key.
1257    pub fn insert_metadata(&mut self, key: String, value: String) {
1258        self.metadata.insert(key, value);
1259    }
1260
1261    /// Get Metadata
1262    /// * `key` - Get the associated value for the metadata key.
1263    pub fn get_metadata(&self, key: &String) -> Option<String> {
1264        self.metadata.get(key).cloned()
1265    }
1266}
1267
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Read the shared feature fixture; lines that fail to parse become NaN.
    fn read_features() -> Vec<f64> {
        fs::read_to_string("resources/contiguous_with_missing.csv")
            .expect("Something went wrong reading the file")
            .lines()
            .map(|x| x.parse::<f64>().unwrap_or(f64::NAN))
            .collect()
    }

    /// Read a target column from `path`; every line must parse as a float.
    fn read_target(path: &str) -> Vec<f64> {
        fs::read_to_string(path)
            .expect("Something went wrong reading the file")
            .lines()
            .map(|x| x.parse::<f64>().unwrap())
            .collect()
    }

    /// Fit `booster` with unit sample weights, check the length of the `Average`
    /// contributions, and print a few diagnostics about the fitted model.
    fn fit_and_inspect(mut booster: GradientBooster, data: &Matrix<f64>, y: &[f64]) {
        let sample_weight = vec![1.; y.len()];
        booster.fit(data, y, &sample_weight, None).unwrap();
        let preds = booster.predict(data, false);
        let contribs = booster.predict_contributions(data, ContributionsMethod::Average, false);
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
        println!("{}", booster.trees[0]);
        println!("{}", booster.trees[0].nodes.len());
        println!("{}", booster.trees.last().unwrap().nodes.len());
        println!("{:?}", &preds[0..10]);
    }

    #[test]
    fn test_booster_fit_subsample() {
        let data_vec = read_features();
        let y = read_target("resources/performance.csv");
        let data = Matrix::new(&data_vec, 891, 5);

        let booster = GradientBooster::default()
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_subsample(0.5)
            .set_base_score(0.5)
            .set_initialize_base_score(false);
        fit_and_inspect(booster, &data, &y);
    }

    #[test]
    fn test_booster_fit() {
        let data_vec = read_features();
        let y = read_target("resources/performance.csv");
        let data = Matrix::new(&data_vec, 891, 5);

        let booster = GradientBooster::default()
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_base_score(0.5)
            .set_initialize_base_score(false);
        fit_and_inspect(booster, &data, &y);
    }

    #[test]
    fn test_booster_fit_nofitted_base_score() {
        let data_vec = read_features();
        let y = read_target("resources/performance-fare.csv");
        let data = Matrix::new(&data_vec, 891, 5);

        let booster = GradientBooster::default()
            .set_objective_type(ObjectiveType::SquaredLoss)
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_initialize_base_score(true);
        fit_and_inspect(booster, &data, &y);
    }

    #[test]
    fn test_tree_save() {
        let data_vec = read_features();
        let y = read_target("resources/performance.csv");
        let data = Matrix::new(&data_vec, 891, 5);

        let mut booster = GradientBooster::default()
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_base_score(0.5)
            .set_initialize_base_score(false);
        let sample_weight = vec![1.; y.len()];
        booster.fit(&data, &y, &sample_weight, None).unwrap();
        let preds = booster.predict(&data, true);

        // A saved and re-loaded model must reproduce the original predictions.
        booster.save_booster("resources/model64.json").unwrap();
        let reloaded = GradientBooster::load_booster("resources/model64.json").unwrap();
        assert_eq!(reloaded.predict(&data, true)[0..10], preds[0..10]);

        // Test with non-NAN missing.
        booster.missing = 0.;
        booster.save_booster("resources/modelmissing.json").unwrap();
        let reloaded_missing =
            GradientBooster::load_booster("resources/modelmissing.json").unwrap();
        assert_eq!(reloaded_missing.missing, 0.);
        assert_eq!(reloaded_missing.missing, booster.missing);
    }
}