1use crate::binning::bin_matrix;
2use crate::constraints::ConstraintMap;
3use crate::data::{Matrix, RowMajorMatrix};
4use crate::errors::ForustError;
5use crate::metric::{is_comparison_better, metric_callables, Metric, MetricFn};
6use crate::objective::{
7 calc_init_callables, gradient_hessian_callables, LogLoss, ObjectiveFunction, ObjectiveType,
8 SquaredLoss,
9};
10use crate::sampler::{GossSampler, RandomSampler, SampleMethod, Sampler};
11use crate::shapley::predict_contributions_row_shapley;
12use crate::splitter::{MissingBranchSplitter, MissingImputerSplitter, Splitter};
13use crate::tree::Tree;
14use crate::utils::{
15 fmt_vec_output, odds, validate_not_nan_vec, validate_positive_float_field,
16 validate_positive_not_nan_vec,
17};
18use log::info;
19use rand::rngs::StdRng;
20use rand::seq::IteratorRandom;
21use rand::SeedableRng;
22use rayon::prelude::*;
23use serde::{Deserialize, Deserializer, Serialize};
24use std::collections::{HashMap, HashSet};
25use std::fs;
26
/// One evaluation set as supplied by the caller:
/// (features matrix, targets, sample weights).
pub type EvaluationData<'a> = (Matrix<'a, f64>, &'a [f64], &'a [f64]);
/// Internal per-evaluation-set tuple used while training: borrows the eval
/// set and carries a running prediction vector for that set.
pub type TrainingEvaluationData<'a> = (&'a Matrix<'a, f64>, &'a [f64], &'a [f64], Vec<f64>);
// Signature shared by the per-tree feature-importance accumulators on `Tree`;
// the map collects (statistic, use-count) per feature index.
type ImportanceFn = fn(&Tree, &mut HashMap<usize, (f32, usize)>);
30
/// Node-expansion order used while growing a tree. Forwarded to `Tree::fit`;
/// the variant semantics live in the tree-growing code.
#[derive(Serialize, Deserialize)]
pub enum GrowPolicy {
    DepthWise,
    LossGuide,
}

/// Method used by [`GradientBooster::predict_contributions`] to attribute a
/// prediction to the individual features.
#[derive(Serialize, Deserialize)]
pub enum ContributionsMethod {
    /// Per-row weight contributions (`Tree::predict_contributions_row_weight`).
    Weight,
    /// Uses `Tree::distribute_leaf_weights` plus
    /// `Tree::predict_contributions_row_average`.
    Average,
    /// `Tree::predict_contributions_row_branch_difference`.
    BranchDifference,
    /// `Tree::predict_contributions_row_midpoint_difference`.
    MidpointDifference,
    /// `Tree::predict_contributions_row_mode_difference`.
    ModeDifference,
    /// Contributions in probability space; only valid with the `LogLoss`
    /// objective (`predict_contributions` panics otherwise).
    ProbabilityChange,
    /// SHAP-style contributions via `predict_contributions_row_shapley`.
    Shapley,
}

/// Statistic accumulated by [`GradientBooster::calculate_feature_importance`].
/// `Gain`/`Cover` are averaged over the number of times a feature is used;
/// `Weight` and the `Total*` variants are summed.
#[derive(Serialize, Deserialize)]
pub enum ImportanceMethod {
    /// Summed weight statistic (`Tree::calculate_importance_weight`).
    Weight,
    /// Gain averaged over use count (`Tree::calculate_importance_gain`).
    Gain,
    /// Cover averaged over use count (`Tree::calculate_importance_cover`).
    Cover,
    /// Gain summed over all uses.
    TotalGain,
    /// Cover summed over all uses.
    TotalCover,
}

/// How the weight of a missing-value node is handled when fitting with
/// `create_missing_branch`; interpreted by `MissingBranchSplitter`.
#[derive(Serialize, Deserialize, Clone, Copy)]
pub enum MissingNodeTreatment {
    None,
    AssignToParent,
    AverageLeafWeight,
    AverageNodeWeight,
}
81
/// Gradient-boosted decision-tree ensemble.
///
/// Hyperparameters map one-to-one onto the arguments of
/// [`GradientBooster::new`]. The `#[serde(default = "...")]` attributes let
/// models serialized before a field existed still deserialize.
#[derive(Deserialize, Serialize)]
pub struct GradientBooster {
    /// Loss used for gradients/hessians, the initial score and the default
    /// evaluation metric.
    pub objective_type: ObjectiveType,
    /// Number of boosting rounds (one tree per round).
    pub iterations: usize,
    /// Shrinkage factor; forwarded to the splitter.
    pub learning_rate: f32,
    /// Maximum depth of an individual tree.
    pub max_depth: usize,
    /// Maximum number of leaves of an individual tree.
    pub max_leaves: usize,
    /// L1 regularization; forwarded to the splitter.
    #[serde(default = "default_l1")]
    pub l1: f32,
    /// L2 regularization; forwarded to the splitter.
    pub l2: f32,
    /// Regularization term; forwarded to the splitter.
    pub gamma: f32,
    /// Cap on leaf weight updates; forwarded to the splitter.
    #[serde(default = "default_max_delta_step")]
    pub max_delta_step: f32,
    /// Minimum weight required in a leaf; forwarded to the splitter.
    pub min_leaf_weight: f32,
    /// Initial prediction for every sample; recomputed from the data during
    /// `fit` when `initialize_base_score` is true.
    pub base_score: f64,
    /// Number of histogram bins used when discretizing the features.
    pub nbins: u16,
    /// Use rayon parallelism where supported.
    pub parallel: bool,
    /// Forwarded to the splitter.
    pub allow_missing_splits: bool,
    /// Optional per-feature monotonicity constraints.
    pub monotone_constraints: Option<ConstraintMap>,
    /// Row fraction used by `SampleMethod::Random`.
    pub subsample: f32,
    /// GOSS top rate (used by `SampleMethod::Goss`).
    #[serde(default = "default_top_rate")]
    pub top_rate: f64,
    /// GOSS other rate (used by `SampleMethod::Goss`).
    #[serde(default = "default_other_rate")]
    pub other_rate: f64,
    /// Fraction of columns sampled (without replacement) for each tree.
    #[serde(default = "default_colsample_bytree")]
    pub colsample_bytree: f64,
    /// Seed for the `StdRng` driving all sampling.
    pub seed: u64,
    /// Value treated as missing; a serialized `null` deserializes to NaN.
    #[serde(deserialize_with = "parse_missing")]
    pub missing: f64,
    /// If true, fit with `MissingBranchSplitter` (dedicated missing branch),
    /// otherwise with `MissingImputerSplitter`.
    pub create_missing_branch: bool,
    /// Row-sampling strategy applied each boosting round.
    #[serde(default = "default_sample_method")]
    pub sample_method: SampleMethod,
    /// Node-expansion policy forwarded to `Tree::fit`.
    #[serde(default = "default_grow_policy")]
    pub grow_policy: GrowPolicy,
    /// Metric computed on the evaluation sets; `None` means use the
    /// objective's default metric.
    #[serde(default = "default_evaluation_metric")]
    pub evaluation_metric: Option<Metric>,
    /// Stop after this many rounds without metric improvement on the last
    /// evaluation set.
    #[serde(default = "default_early_stopping_rounds")]
    pub early_stopping_rounds: Option<usize>,
    /// Recompute `base_score` from `y`/`sample_weight` at fit time.
    #[serde(default = "default_initialize_base_score")]
    pub initialize_base_score: bool,
    /// Features whose missing branch should not be split further; forwarded
    /// to `MissingBranchSplitter`.
    #[serde(default = "default_terminate_missing_features")]
    pub terminate_missing_features: HashSet<usize>,
    /// Per-iteration metric values (one row per round, one column per
    /// evaluation set); populated during `fit`.
    #[serde(default = "default_evaluation_history")]
    pub evaluation_history: Option<RowMajorMatrix<f64>>,
    /// Iteration with the best early-stopping metric, if tracked.
    #[serde(default = "default_best_iteration")]
    pub best_iteration: Option<usize>,
    /// Exclusive upper bound on the trees used at prediction time
    /// (`best_iteration + 1` when early stopping is tracked).
    #[serde(default = "default_prediction_iteration")]
    pub prediction_iteration: Option<usize>,
    /// Missing-node weight handling; forwarded to the splitter.
    #[serde(default = "default_missing_node_treatment")]
    pub missing_node_treatment: MissingNodeTreatment,
    /// Log progress every `log_iterations` rounds (0 disables logging).
    #[serde(default = "default_log_iterations")]
    pub log_iterations: usize,
    /// Forwarded to `MissingBranchSplitter`.
    #[serde(default = "default_force_children_to_bound_parent")]
    pub force_children_to_bound_parent: bool,
    /// The fitted trees, in boosting order.
    pub trees: Vec<Tree>,
    /// Free-form key/value strings attached to the model; survives
    /// serialization via `save_booster`/`load_booster`.
    metadata: HashMap<String, String>,
}
193
// Serde `default` hooks: values used when deserializing a model that was
// saved before the corresponding field existed, so old serialized boosters
// keep loading.
fn default_l1() -> f32 {
    0.0
}
fn default_max_delta_step() -> f32 {
    0.0
}

// NOTE(review): deserialization defaults `initialize_base_score` to `false`
// while `GradientBooster::default()` passes `true` — this asymmetry keeps the
// behavior of previously serialized models unchanged.
fn default_initialize_base_score() -> bool {
    false
}

fn default_grow_policy() -> GrowPolicy {
    GrowPolicy::DepthWise
}

fn default_top_rate() -> f64 {
    0.1
}
fn default_other_rate() -> f64 {
    0.2
}
fn default_sample_method() -> SampleMethod {
    SampleMethod::None
}
fn default_evaluation_metric() -> Option<Metric> {
    None
}
fn default_early_stopping_rounds() -> Option<usize> {
    None
}
fn default_evaluation_history() -> Option<RowMajorMatrix<f64>> {
    None
}
fn default_best_iteration() -> Option<usize> {
    None
}
fn default_prediction_iteration() -> Option<usize> {
    None
}
fn default_terminate_missing_features() -> HashSet<usize> {
    HashSet::new()
}
fn default_colsample_bytree() -> f64 {
    1.0
}
fn default_missing_node_treatment() -> MissingNodeTreatment {
    MissingNodeTreatment::AssignToParent
}

fn default_log_iterations() -> usize {
    0
}
fn default_force_children_to_bound_parent() -> bool {
    false
}

/// Deserialize the `missing` field, mapping a serialized `null` to
/// `f64::NAN` (NaN itself is not representable in JSON).
fn parse_missing<'de, D>(d: D) -> Result<f64, D::Error>
where
    D: Deserializer<'de>,
{
    Deserialize::deserialize(d).map(|x: Option<_>| x.unwrap_or(f64::NAN))
}
256
impl Default for GradientBooster {
    /// Booster with library-default hyperparameters. The arguments are
    /// positional — the trailing comments name the parameter of
    /// [`GradientBooster::new`] each value binds to.
    fn default() -> Self {
        Self::new(
            ObjectiveType::LogLoss,
            10,         // iterations
            0.3,        // learning_rate
            5,          // max_depth
            usize::MAX, // max_leaves
            0.,         // l1
            1.,         // l2
            0.,         // gamma
            0.,         // max_delta_step
            1.,         // min_leaf_weight
            0.5,        // base_score
            256,        // nbins
            true,       // parallel
            true,       // allow_missing_splits
            None,       // monotone_constraints
            1.,         // subsample
            0.1,        // top_rate
            0.2,        // other_rate
            1.0,        // colsample_bytree
            0,          // seed
            f64::NAN,   // missing
            false,      // create_missing_branch
            SampleMethod::None,
            GrowPolicy::DepthWise,
            None, // evaluation_metric
            None, // early_stopping_rounds
            true, // initialize_base_score
            HashSet::new(),
            MissingNodeTreatment::AssignToParent,
            0,     // log_iterations
            false, // force_children_to_bound_parent
        )
        // Safe: these defaults satisfy `validate_parameters`.
        .unwrap()
    }
}
295
296impl GradientBooster {
    /// Construct a booster from the full parameter set.
    ///
    /// The parameters map one-to-one onto the public fields of
    /// [`GradientBooster`]; see the field documentation for their meaning.
    /// Mutable training state (`trees`, `evaluation_history`,
    /// `best_iteration`, `prediction_iteration`, `metadata`) starts empty.
    ///
    /// # Errors
    /// Returns `ForustError` if any float hyperparameter fails the
    /// `validate_positive_float_field!` checks in `validate_parameters`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        objective_type: ObjectiveType,
        iterations: usize,
        learning_rate: f32,
        max_depth: usize,
        max_leaves: usize,
        l1: f32,
        l2: f32,
        gamma: f32,
        max_delta_step: f32,
        min_leaf_weight: f32,
        base_score: f64,
        nbins: u16,
        parallel: bool,
        allow_missing_splits: bool,
        monotone_constraints: Option<ConstraintMap>,
        subsample: f32,
        top_rate: f64,
        other_rate: f64,
        colsample_bytree: f64,
        seed: u64,
        missing: f64,
        create_missing_branch: bool,
        sample_method: SampleMethod,
        grow_policy: GrowPolicy,
        evaluation_metric: Option<Metric>,
        early_stopping_rounds: Option<usize>,
        initialize_base_score: bool,
        terminate_missing_features: HashSet<usize>,
        missing_node_treatment: MissingNodeTreatment,
        log_iterations: usize,
        force_children_to_bound_parent: bool,
    ) -> Result<Self, ForustError> {
        let booster = GradientBooster {
            objective_type,
            iterations,
            learning_rate,
            max_depth,
            max_leaves,
            l1,
            l2,
            gamma,
            max_delta_step,
            min_leaf_weight,
            base_score,
            nbins,
            parallel,
            allow_missing_splits,
            monotone_constraints,
            subsample,
            top_rate,
            other_rate,
            colsample_bytree,
            seed,
            missing,
            create_missing_branch,
            sample_method,
            grow_policy,
            evaluation_metric,
            early_stopping_rounds,
            initialize_base_score,
            terminate_missing_features,
            evaluation_history: None,
            best_iteration: None,
            prediction_iteration: None,
            missing_node_treatment,
            log_iterations,
            force_children_to_bound_parent,
            trees: Vec::new(),
            metadata: HashMap::new(),
        };
        // Reject invalid hyperparameters before handing the booster out.
        booster.validate_parameters()?;
        Ok(booster)
    }
415
    /// Validate the float hyperparameters via `validate_positive_float_field!`.
    /// Called from `new` so an invalid booster is never returned.
    fn validate_parameters(&self) -> Result<(), ForustError> {
        validate_positive_float_field!(self.learning_rate);
        validate_positive_float_field!(self.l1);
        validate_positive_float_field!(self.l2);
        validate_positive_float_field!(self.gamma);
        validate_positive_float_field!(self.max_delta_step);
        validate_positive_float_field!(self.min_leaf_weight);
        validate_positive_float_field!(self.subsample);
        validate_positive_float_field!(self.top_rate);
        validate_positive_float_field!(self.other_rate);
        validate_positive_float_field!(self.colsample_bytree);
        Ok(())
    }
429
    /// Fit the booster on `data`.
    ///
    /// * `data` - feature matrix, one row per sample.
    /// * `y` - targets; must contain no NaN.
    /// * `sample_weight` - per-sample weights; must be positive and non-NaN.
    /// * `evaluation_data` - optional evaluation sets used for metric
    ///   logging and early stopping.
    ///
    /// # Errors
    /// Returns `ForustError` if input validation, binning or tree fitting
    /// fails.
    pub fn fit(
        &mut self,
        data: &Matrix<f64>,
        y: &[f64],
        sample_weight: &[f64],
        evaluation_data: Option<Vec<EvaluationData>>,
    ) -> Result<(), ForustError> {
        // Validate targets and weights (train and eval sets) up front.
        validate_not_nan_vec(y, "y".to_string())?;
        validate_positive_not_nan_vec(sample_weight, "sample_weight".to_string())?;
        if let Some(eval_data) = &evaluation_data {
            for (i, (_, eval_y, eval_sample_weight)) in eval_data.iter().enumerate() {
                validate_not_nan_vec(eval_y, format!("eval set {} y", i).to_string())?;
                validate_positive_not_nan_vec(
                    eval_sample_weight,
                    format!("eval set {} sample_weight", i).to_string(),
                )?;
            }
        }

        // Absent constraints are treated as an empty constraint map.
        let constraints_map = self
            .monotone_constraints
            .as_ref()
            .unwrap_or(&ConstraintMap::new())
            .to_owned();
        // The splitter type decides how missing values are routed: either a
        // dedicated missing branch, or imputation into an existing branch.
        if self.create_missing_branch {
            let splitter = MissingBranchSplitter {
                l1: self.l1,
                l2: self.l2,
                max_delta_step: self.max_delta_step,
                gamma: self.gamma,
                min_leaf_weight: self.min_leaf_weight,
                learning_rate: self.learning_rate,
                allow_missing_splits: self.allow_missing_splits,
                constraints_map,
                terminate_missing_features: self.terminate_missing_features.clone(),
                missing_node_treatment: self.missing_node_treatment,
                force_children_to_bound_parent: self.force_children_to_bound_parent,
            };
            self.fit_trees(y, sample_weight, data, &splitter, evaluation_data)?;
        } else {
            let splitter = MissingImputerSplitter {
                l1: self.l1,
                l2: self.l2,
                max_delta_step: self.max_delta_step,
                gamma: self.gamma,
                min_leaf_weight: self.min_leaf_weight,
                learning_rate: self.learning_rate,
                allow_missing_splits: self.allow_missing_splits,
                constraints_map,
            };
            self.fit_trees(y, sample_weight, data, &splitter, evaluation_data)?;
        };

        Ok(())
    }
492
    /// Pick the row indices used to fit the next tree according to
    /// `sample_method`.
    ///
    /// Returns `(chosen, excluded)` index vectors; with `SampleMethod::None`
    /// every row is chosen and nothing is excluded. The samplers take
    /// `grad`/`hess` mutably and may rescale them in place.
    fn sample_index(
        &self,
        rng: &mut StdRng,
        index: &[usize],
        grad: &mut [f32],
        hess: &mut [f32],
    ) -> (Vec<usize>, Vec<usize>) {
        match self.sample_method {
            SampleMethod::None => (index.to_owned(), Vec::new()),
            SampleMethod::Random => {
                RandomSampler::new(self.subsample).sample(rng, index, grad, hess)
            }
            SampleMethod::Goss => {
                GossSampler::new(self.top_rate, self.other_rate).sample(rng, index, grad, hess)
            }
        }
    }
510
511 fn get_metric_fn(&self) -> (MetricFn, bool) {
512 let metric = match &self.evaluation_metric {
513 None => match self.objective_type {
514 ObjectiveType::LogLoss => LogLoss::default_metric(),
515 ObjectiveType::SquaredLoss => SquaredLoss::default_metric(),
516 },
517 Some(v) => *v,
518 };
519 metric_callables(&metric)
520 }
521
522 fn reset(&mut self) {
523 self.trees = Vec::new();
524 self.evaluation_history = None;
525 self.best_iteration = None;
526 self.prediction_iteration = None;
527 }
528
    /// Core boosting loop: bins the data, then repeatedly fits a tree to the
    /// current gradients/hessians, updating predictions, evaluation metrics
    /// and the early-stopping state.
    ///
    /// # Errors
    /// Propagates any error from `bin_matrix`.
    fn fit_trees<T: Splitter>(
        &mut self,
        y: &[f64],
        sample_weight: &[f64],
        data: &Matrix<f64>,
        splitter: &T,
        evaluation_data: Option<Vec<EvaluationData>>,
    ) -> Result<(), ForustError> {
        // Refitting an already-fitted booster starts over.
        if !self.trees.is_empty() {
            self.reset()
        }

        let mut rng = StdRng::seed_from_u64(self.seed);

        // Optionally derive the initial prediction from the data.
        if self.initialize_base_score {
            self.base_score = calc_init_callables(&self.objective_type)(y, sample_weight);
        }

        // Running predictions for the training data.
        let mut yhat = vec![self.base_score; y.len()];

        let calc_grad_hess = gradient_hessian_callables(&self.objective_type);
        let (mut grad, mut hess) = calc_grad_hess(y, &yhat, sample_weight);

        // Discretize the features once; trees split on bin indices.
        let binned_data = bin_matrix(data, sample_weight, self.nbins, self.missing)?;
        let bdata = Matrix::new(&binned_data.binned_data, data.rows, data.cols);

        // Each evaluation set carries its own running prediction vector.
        let mut evaluation_sets: Option<Vec<TrainingEvaluationData>> =
            evaluation_data.as_ref().map(|evals| {
                evals
                    .iter()
                    .map(|(d, y, w)| (d, *y, *w, vec![self.base_score; y.len()]))
                    .collect()
            });

        let mut best_metric: Option<f64> = None;

        let mut stop_early = false;
        let col_index: Vec<usize> = (0..data.cols).collect();
        for i in 0..self.iterations {
            let verbose = if self.log_iterations == 0 {
                false
            } else {
                i % self.log_iterations == 0
            };
            // Row sampling; samplers may rescale grad/hess in place.
            let (chosen_index, _excluded_index) =
                self.sample_index(&mut rng, &data.index, &mut grad, &mut hess);
            let mut tree = Tree::new();

            // Column sampling: draw a sorted subset of columns for this tree
            // (skipped entirely when colsample_bytree == 1.0).
            let colsample_index: Vec<usize> = if self.colsample_bytree == 1.0 {
                Vec::new()
            } else {
                let amount = ((col_index.len() as f64) * self.colsample_bytree).floor() as usize;
                let mut v: Vec<usize> = col_index
                    .iter()
                    .choose_multiple(&mut rng, amount)
                    .iter()
                    .map(|i| **i)
                    .collect();
                v.sort();
                v
            };

            let fit_col_index = if self.colsample_bytree == 1.0 {
                &col_index
            } else {
                &colsample_index
            };

            tree.fit(
                &bdata,
                chosen_index,
                fit_col_index,
                &binned_data.cuts,
                &grad,
                &hess,
                splitter,
                self.max_leaves,
                self.max_depth,
                self.parallel,
                &self.sample_method,
                &self.grow_policy,
            );

            self.update_predictions_inplace(&mut yhat, &tree, data);

            if let Some(eval_sets) = &mut evaluation_sets {
                if self.evaluation_history.is_none() {
                    self.evaluation_history =
                        Some(RowMajorMatrix::new(Vec::new(), 0, eval_sets.len()));
                }
                let mut metrics: Vec<f64> = Vec::new();
                let n_eval_sets = eval_sets.len();
                for (eval_i, (data, y, w, yhat)) in eval_sets.iter_mut().enumerate() {
                    self.update_predictions_inplace(yhat, &tree, data);
                    let (metric_fn, maximize) = self.get_metric_fn();
                    let m = metric_fn(y, yhat, w);
                    // Early stopping is evaluated on the LAST eval set only.
                    if (eval_i + 1) == n_eval_sets {
                        if let Some(early_stopping_rounds) = self.early_stopping_rounds {
                            best_metric = match best_metric {
                                None => {
                                    self.update_best_iteration(i);
                                    Some(m)
                                }
                                Some(v) => {
                                    if is_comparison_better(v, m, maximize) {
                                        self.update_best_iteration(i);
                                        Some(m)
                                    } else {
                                        // No improvement: stop once we are
                                        // `early_stopping_rounds` past the best.
                                        if let Some(best_iteration) = self.best_iteration {
                                            if i - best_iteration >= early_stopping_rounds {
                                                if self.log_iterations > 0 {
                                                    info!("Stopping early at iteration {} with metric value {}", i, m)
                                                }
                                                stop_early = true;
                                            }
                                        }
                                        Some(v)
                                    }
                                }
                            };
                        }
                    }
                    metrics.push(m);
                }
                if verbose {
                    info!(
                        "Iteration {} evaluation data values: {}",
                        i,
                        fmt_vec_output(&metrics)
                    );
                }
                if let Some(history) = &mut self.evaluation_history {
                    history.append_row(metrics);
                }
            }
            // The triggering tree is still kept before breaking.
            self.trees.push(tree);

            if stop_early {
                break;
            }

            // Refresh gradients/hessians for the next round.
            (grad, hess) = calc_grad_hess(y, &yhat, sample_weight);
            if verbose {
                info!("Completed iteration {} of {}", i, self.iterations);
            }
        }
        if self.log_iterations > 0 {
            info!(
                "Finished training booster with {} iterations.",
                self.trees.len()
            );
        }
        Ok(())
    }
705
    /// Record `i` as the best iteration; `prediction_iteration` is stored as
    /// `i + 1` because it is used as an exclusive slice bound in
    /// `get_prediction_trees`.
    fn update_best_iteration(&mut self, i: usize) {
        self.best_iteration = Some(i);
        self.prediction_iteration = Some(i + 1);
    }
710
711 fn update_predictions_inplace(&self, yhat: &mut [f64], tree: &Tree, data: &Matrix<f64>) {
712 let preds = tree.predict(data, self.parallel, &self.missing);
713 yhat.iter_mut().zip(preds).for_each(|(i, j)| *i += j);
714 }
715
716 pub fn fit_unweighted(
721 &mut self,
722 data: &Matrix<f64>,
723 y: &[f64],
724 evaluation_data: Option<Vec<EvaluationData>>,
725 ) -> Result<(), ForustError> {
726 let sample_weight = vec![1.0; data.rows];
727 self.fit(data, y, &sample_weight, evaluation_data)
728 }
729
730 pub fn predict(&self, data: &Matrix<f64>, parallel: bool) -> Vec<f64> {
734 let mut init_preds = vec![self.base_score; data.rows];
735 self.get_prediction_trees().iter().for_each(|tree| {
736 for (p_, val) in init_preds
737 .iter_mut()
738 .zip(tree.predict(data, parallel, &self.missing))
739 {
740 *p_ += val;
741 }
742 });
743 init_preds
744 }
745
746 pub fn predict_leaf_indices(&self, data: &Matrix<f64>) -> Vec<usize> {
748 self.get_prediction_trees()
749 .iter()
750 .flat_map(|tree| tree.predict_leaf_indices(data, &self.missing))
751 .collect()
752 }
753
    /// Compute per-feature contributions for every row of `data` using the
    /// requested `method`. The result has `data.cols + 1` values per row;
    /// the extra slot holds the bias term.
    ///
    /// # Panics
    /// Panics if `ContributionsMethod::ProbabilityChange` is requested with
    /// any objective other than `LogLoss`.
    pub fn predict_contributions(
        &self,
        data: &Matrix<f64>,
        method: ContributionsMethod,
        parallel: bool,
    ) -> Vec<f64> {
        match method {
            ContributionsMethod::Average => self.predict_contributions_average(data, parallel),
            ContributionsMethod::ProbabilityChange => {
                // Probability-space contributions only make sense for the
                // logistic objective.
                match self.objective_type {
                    ObjectiveType::LogLoss => {},
                    _ => panic!("ProbabilityChange contributions method is only valid when LogLoss objective is used.")
                }
                self.predict_contributions_probability_change(data, parallel)
            }
            // All remaining methods are computed tree-by-tree with a per-row
            // callback.
            _ => self.predict_contributions_tree_alone(data, parallel, method),
        }
    }
773
    /// Contribution methods that are computed independently per tree with a
    /// per-row `Tree` callback (everything except `Average` and
    /// `ProbabilityChange`).
    fn predict_contributions_tree_alone(
        &self,
        data: &Matrix<f64>,
        parallel: bool,
        method: ContributionsMethod,
    ) -> Vec<f64> {
        // One chunk of `data.cols` feature slots plus a trailing bias slot
        // per row.
        let mut contribs = vec![0.; (data.cols + 1) * data.rows];

        // Seed the last slot of every row with the model's base score.
        let bias_idx = data.cols + 1;
        contribs
            .iter_mut()
            .skip(bias_idx - 1)
            .step_by(bias_idx)
            .for_each(|v| *v += self.base_score);

        // Resolve the per-row contribution function once, outside the loops.
        let row_pred_fn = match method {
            ContributionsMethod::Weight => Tree::predict_contributions_row_weight,
            ContributionsMethod::BranchDifference => {
                Tree::predict_contributions_row_branch_difference
            }
            ContributionsMethod::MidpointDifference => {
                Tree::predict_contributions_row_midpoint_difference
            }
            ContributionsMethod::ModeDifference => Tree::predict_contributions_row_mode_difference,
            ContributionsMethod::Shapley => predict_contributions_row_shapley,
            // These two are routed to dedicated methods by
            // `predict_contributions` and can never reach this function.
            ContributionsMethod::Average | ContributionsMethod::ProbabilityChange => unreachable!(),
        };
        // Identical traversal, serial or rayon-parallel over rows.
        if parallel {
            data.index
                .par_iter()
                .zip(contribs.par_chunks_mut(data.cols + 1))
                .for_each(|(row, c)| {
                    let r_ = data.get_row(*row);
                    self.get_prediction_trees().iter().for_each(|t| {
                        row_pred_fn(t, &r_, c, &self.missing);
                    });
                });
        } else {
            data.index
                .iter()
                .zip(contribs.chunks_mut(data.cols + 1))
                .for_each(|(row, c)| {
                    let r_ = data.get_row(*row);
                    self.get_prediction_trees().iter().for_each(|t| {
                        row_pred_fn(t, &r_, c, &self.missing);
                    });
                });
        }

        contribs
    }
833
834 fn get_prediction_trees(&self) -> &[Tree] {
837 let n_iterations = self.prediction_iteration.unwrap_or(self.trees.len());
838 &self.trees[..n_iterations]
839 }
840
    /// `ContributionsMethod::Average`: contributions based on each tree's
    /// distributed leaf weights.
    fn predict_contributions_average(&self, data: &Matrix<f64>, parallel: bool) -> Vec<f64> {
        // Precompute the distributed leaf weights once per tree.
        let weights: Vec<Vec<f64>> = if parallel {
            self.get_prediction_trees()
                .par_iter()
                .map(|t| t.distribute_leaf_weights())
                .collect()
        } else {
            self.get_prediction_trees()
                .iter()
                .map(|t| t.distribute_leaf_weights())
                .collect()
        };
        // One chunk of `data.cols` feature slots plus a trailing bias slot
        // per row.
        let mut contribs = vec![0.; (data.cols + 1) * data.rows];

        // Seed the last slot of every row with the model's base score.
        let bias_idx = data.cols + 1;
        contribs
            .iter_mut()
            .skip(bias_idx - 1)
            .step_by(bias_idx)
            .for_each(|v| *v += self.base_score);

        // Identical traversal, serial or rayon-parallel over rows; each tree
        // is paired with its precomputed weights.
        if parallel {
            data.index
                .par_iter()
                .zip(contribs.par_chunks_mut(data.cols + 1))
                .for_each(|(row, c)| {
                    let r_ = data.get_row(*row);
                    self.get_prediction_trees()
                        .iter()
                        .zip(weights.iter())
                        .for_each(|(t, w)| {
                            t.predict_contributions_row_average(&r_, c, w, &self.missing);
                        });
                });
        } else {
            data.index
                .iter()
                .zip(contribs.chunks_mut(data.cols + 1))
                .for_each(|(row, c)| {
                    let r_ = data.get_row(*row);
                    self.get_prediction_trees()
                        .iter()
                        .zip(weights.iter())
                        .for_each(|(t, w)| {
                            t.predict_contributions_row_average(&r_, c, w, &self.missing);
                        });
                });
        }

        contribs
    }
901
    /// `ContributionsMethod::ProbabilityChange`: contributions in probability
    /// space. The bias slot is seeded with `odds(base_score)`; each tree's
    /// callback receives the score accumulated by the previous trees via the
    /// `fold` and returns the updated accumulator.
    fn predict_contributions_probability_change(
        &self,
        data: &Matrix<f64>,
        parallel: bool,
    ) -> Vec<f64> {
        // One chunk of `data.cols` feature slots plus a trailing bias slot
        // per row.
        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
        let bias_idx = data.cols + 1;
        contribs
            .iter_mut()
            .skip(bias_idx - 1)
            .step_by(bias_idx)
            .for_each(|v| *v += odds(self.base_score));

        // Identical traversal, serial or rayon-parallel over rows.
        if parallel {
            data.index
                .par_iter()
                .zip(contribs.par_chunks_mut(data.cols + 1))
                .for_each(|(row, c)| {
                    let r_ = data.get_row(*row);
                    self.get_prediction_trees()
                        .iter()
                        .fold(self.base_score, |acc, t| {
                            t.predict_contributions_row_probability_change(
                                &r_,
                                c,
                                &self.missing,
                                acc,
                            )
                        });
                });
        } else {
            data.index
                .iter()
                .zip(contribs.chunks_mut(data.cols + 1))
                .for_each(|(row, c)| {
                    let r_ = data.get_row(*row);
                    self.get_prediction_trees()
                        .iter()
                        .fold(self.base_score, |acc, t| {
                            t.predict_contributions_row_probability_change(
                                &r_,
                                c,
                                &self.missing,
                                acc,
                            )
                        });
                });
        }
        contribs
    }
952
953 pub fn value_partial_dependence(&self, feature: usize, value: f64) -> f64 {
959 let pd: f64 = if self.parallel {
960 self.get_prediction_trees()
961 .par_iter()
962 .map(|t| t.value_partial_dependence(feature, value, &self.missing))
963 .sum()
964 } else {
965 self.get_prediction_trees()
966 .iter()
967 .map(|t| t.value_partial_dependence(feature, value, &self.missing))
968 .sum()
969 };
970 pd + self.base_score
971 }
972
973 pub fn calculate_feature_importance(
978 &self,
979 method: ImportanceMethod,
980 normalize: bool,
981 ) -> HashMap<usize, f32> {
982 let (average, importance_fn): (bool, ImportanceFn) = match method {
983 ImportanceMethod::Weight => (false, Tree::calculate_importance_weight),
984 ImportanceMethod::Gain => (true, Tree::calculate_importance_gain),
985 ImportanceMethod::TotalGain => (false, Tree::calculate_importance_gain),
986 ImportanceMethod::Cover => (true, Tree::calculate_importance_cover),
987 ImportanceMethod::TotalCover => (false, Tree::calculate_importance_cover),
988 };
989 let mut stats = HashMap::new();
990 for tree in self.trees.iter() {
991 importance_fn(tree, &mut stats)
992 }
993
994 let importance = stats
995 .iter()
996 .map(|(k, (v, c))| {
997 if average {
998 (*k, v / (*c as f32))
999 } else {
1000 (*k, *v)
1001 }
1002 })
1003 .collect::<HashMap<usize, f32>>();
1004
1005 if normalize {
1006 let mut values: Vec<f32> = importance.values().copied().collect();
1010 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1012 let total: f32 = values.iter().sum();
1013 importance.iter().map(|(k, v)| (*k, v / total)).collect()
1014 } else {
1015 importance
1016 }
1017 }
1018
1019 pub fn save_booster(&self, path: &str) -> Result<(), ForustError> {
1023 let model = self.json_dump()?;
1024 match fs::write(path, model) {
1025 Err(e) => Err(ForustError::UnableToWrite(e.to_string())),
1026 Ok(_) => Ok(()),
1027 }
1028 }
1029
1030 pub fn json_dump(&self) -> Result<String, ForustError> {
1032 match serde_json::to_string(self) {
1033 Ok(s) => Ok(s),
1034 Err(e) => Err(ForustError::UnableToWrite(e.to_string())),
1035 }
1036 }
1037
1038 pub fn from_json(json_str: &str) -> Result<Self, ForustError> {
1042 let model = serde_json::from_str::<GradientBooster>(json_str);
1043 match model {
1044 Ok(m) => Ok(m),
1045 Err(e) => Err(ForustError::UnableToRead(e.to_string())),
1046 }
1047 }
1048
1049 pub fn load_booster(path: &str) -> Result<Self, ForustError> {
1053 let json_str = match fs::read_to_string(path) {
1054 Ok(s) => Ok(s),
1055 Err(e) => Err(ForustError::UnableToRead(e.to_string())),
1056 }?;
1057 Self::from_json(&json_str)
1058 }
1059
    /// Set the objective function. Consumes and returns `self` for chaining.
    pub fn set_objective_type(mut self, objective_type: ObjectiveType) -> Self {
        self.objective_type = objective_type;
        self
    }

    /// Set the number of boosting rounds.
    pub fn set_iterations(mut self, iterations: usize) -> Self {
        self.iterations = iterations;
        self
    }

    /// Set the learning rate (shrinkage).
    pub fn set_learning_rate(mut self, learning_rate: f32) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum tree depth.
    pub fn set_max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Set the maximum number of leaves per tree.
    pub fn set_max_leaves(mut self, max_leaves: usize) -> Self {
        self.max_leaves = max_leaves;
        self
    }

    /// Set the number of histogram bins used for feature discretization.
    pub fn set_nbins(mut self, nbins: u16) -> Self {
        self.nbins = nbins;
        self
    }

    /// Set the L1 regularization term.
    pub fn set_l1(mut self, l1: f32) -> Self {
        self.l1 = l1;
        self
    }

    /// Set the L2 regularization term.
    pub fn set_l2(mut self, l2: f32) -> Self {
        self.l2 = l2;
        self
    }

    /// Set gamma.
    pub fn set_gamma(mut self, gamma: f32) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set the maximum delta step for leaf weight updates.
    pub fn set_max_delta_step(mut self, max_delta_step: f32) -> Self {
        self.max_delta_step = max_delta_step;
        self
    }

    /// Set the minimum leaf weight.
    pub fn set_min_leaf_weight(mut self, min_leaf_weight: f32) -> Self {
        self.min_leaf_weight = min_leaf_weight;
        self
    }

    /// Set the base (initial) score.
    pub fn set_base_score(mut self, base_score: f64) -> Self {
        self.base_score = base_score;
        self
    }

    /// Control whether `fit` recomputes `base_score` from the data.
    pub fn set_initialize_base_score(mut self, initialize_base_score: bool) -> Self {
        self.initialize_base_score = initialize_base_score;
        self
    }

    /// Enable/disable rayon parallelism during fitting.
    pub fn set_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Control whether splits on missing values are allowed.
    pub fn set_allow_missing_splits(mut self, allow_missing_splits: bool) -> Self {
        self.allow_missing_splits = allow_missing_splits;
        self
    }

    /// Set the monotonicity constraint map (`None` clears constraints).
    pub fn set_monotone_constraints(mut self, monotone_constraints: Option<ConstraintMap>) -> Self {
        self.monotone_constraints = monotone_constraints;
        self
    }

    /// Set the row-sampling fraction used by `SampleMethod::Random`.
    pub fn set_subsample(mut self, subsample: f32) -> Self {
        self.subsample = subsample;
        self
    }

    /// Set the per-tree column-sampling fraction.
    pub fn set_colsample_bytree(mut self, colsample_bytree: f64) -> Self {
        self.colsample_bytree = colsample_bytree;
        self
    }

    /// Set the RNG seed.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Set the value that is treated as missing.
    pub fn set_missing(mut self, missing: f64) -> Self {
        self.missing = missing;
        self
    }

    /// Choose between the missing-branch and missing-imputer splitters.
    pub fn set_create_missing_branch(mut self, create_missing_branch: bool) -> Self {
        self.create_missing_branch = create_missing_branch;
        self
    }

    /// Set the row-sampling method.
    pub fn set_sample_method(mut self, sample_method: SampleMethod) -> Self {
        self.sample_method = sample_method;
        self
    }

    /// Set the evaluation metric (`None` falls back to the objective's
    /// default).
    pub fn set_evaluation_metric(mut self, evaluation_metric: Option<Metric>) -> Self {
        self.evaluation_metric = evaluation_metric;
        self
    }

    /// Set the early-stopping round count (`None` disables early stopping).
    pub fn set_early_stopping_rounds(mut self, early_stopping_rounds: Option<usize>) -> Self {
        self.early_stopping_rounds = early_stopping_rounds;
        self
    }

    /// Set the iteration whose trees are used at prediction time. Stored
    /// internally as `i + 1` because it is used as an exclusive slice bound
    /// in `get_prediction_trees`.
    pub fn set_prediction_iteration(mut self, prediction_iteration: Option<usize>) -> Self {
        self.prediction_iteration = prediction_iteration.map(|i| i + 1);
        self
    }

    /// Set the features whose missing branch should not be split further.
    pub fn set_terminate_missing_features(
        mut self,
        terminate_missing_features: HashSet<usize>,
    ) -> Self {
        self.terminate_missing_features = terminate_missing_features;
        self
    }
1253
    /// Attach (or overwrite) a free-form metadata key/value pair on the
    /// model; metadata is a serialized field, so it survives
    /// `save_booster`/`load_booster`.
    pub fn insert_metadata(&mut self, key: String, value: String) {
        self.metadata.insert(key, value);
    }
1260
1261 pub fn get_metadata(&self, key: &String) -> Option<String> {
1264 self.metadata.get(key).cloned()
1265 }
1266}
1267
1268#[cfg(test)]
1269mod tests {
1270 use super::*;
1271 use std::fs;
1272
    // End-to-end fit with row subsampling on the bundled CSV fixtures;
    // checks the contributions shape and prints diagnostics.
    #[test]
    fn test_booster_fit_subsample() {
        let file = fs::read_to_string("resources/contiguous_with_missing.csv")
            .expect("Something went wrong reading the file");
        // Unparseable cells become NaN, i.e. missing values.
        let data_vec: Vec<f64> = file
            .lines()
            .map(|x| x.parse::<f64>().unwrap_or(f64::NAN))
            .collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();

        let data = Matrix::new(&data_vec, 891, 5);
        let mut booster = GradientBooster::default()
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_subsample(0.5)
            .set_base_score(0.5)
            .set_initialize_base_score(false);
        let sample_weight = vec![1.; y.len()];
        booster.fit(&data, &y, &sample_weight, None).unwrap();
        let preds = booster.predict(&data, false);
        let contribs = booster.predict_contributions(&data, ContributionsMethod::Average, false);
        // One contribution per feature plus the bias slot, per row.
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
        println!("{}", booster.trees[0]);
        println!("{}", booster.trees[0].nodes.len());
        println!("{}", booster.trees.last().unwrap().nodes.len());
        println!("{:?}", &preds[0..10]);
    }
1304
    // Baseline end-to-end fit (no subsampling) on the bundled CSV fixtures;
    // checks the contributions shape and prints diagnostics.
    #[test]
    fn test_booster_fit() {
        let file = fs::read_to_string("resources/contiguous_with_missing.csv")
            .expect("Something went wrong reading the file");
        // Unparseable cells become NaN, i.e. missing values.
        let data_vec: Vec<f64> = file
            .lines()
            .map(|x| x.parse::<f64>().unwrap_or(f64::NAN))
            .collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();

        let data = Matrix::new(&data_vec, 891, 5);
        let mut booster = GradientBooster::default()
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_base_score(0.5)
            .set_initialize_base_score(false);

        let sample_weight = vec![1.; y.len()];
        booster.fit(&data, &y, &sample_weight, None).unwrap();
        let preds = booster.predict(&data, false);
        let contribs = booster.predict_contributions(&data, ContributionsMethod::Average, false);
        // One contribution per feature plus the bias slot, per row.
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
        println!("{}", booster.trees[0]);
        println!("{}", booster.trees[0].nodes.len());
        println!("{}", booster.trees.last().unwrap().nodes.len());
        println!("{:?}", &preds[0..10]);
    }
1336
    // Regression-objective fit with a data-derived base score
    // (`initialize_base_score(true)`), against the continuous target
    // fixture.
    #[test]
    fn test_booster_fit_nofitted_base_score() {
        let file = fs::read_to_string("resources/contiguous_with_missing.csv")
            .expect("Something went wrong reading the file");
        // Unparseable cells become NaN, i.e. missing values.
        let data_vec: Vec<f64> = file
            .lines()
            .map(|x| x.parse::<f64>().unwrap_or(f64::NAN))
            .collect();
        let file = fs::read_to_string("resources/performance-fare.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();

        let data = Matrix::new(&data_vec, 891, 5);
        let mut booster = GradientBooster::default()
            .set_objective_type(ObjectiveType::SquaredLoss)
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_initialize_base_score(true);
        let sample_weight = vec![1.; y.len()];
        booster.fit(&data, &y, &sample_weight, None).unwrap();
        let preds = booster.predict(&data, false);
        let contribs = booster.predict_contributions(&data, ContributionsMethod::Average, false);
        // One contribution per feature plus the bias slot, per row.
        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
        println!("{}", booster.trees[0]);
        println!("{}", booster.trees[0].nodes.len());
        println!("{}", booster.trees.last().unwrap().nodes.len());
        println!("{:?}", &preds[0..10]);
    }
1367
    // Save/load round trip: a reloaded booster must reproduce the original
    // predictions, and a mutated `missing` value must survive serialization.
    #[test]
    fn test_tree_save() {
        let file = fs::read_to_string("resources/contiguous_with_missing.csv")
            .expect("Something went wrong reading the file");
        // Unparseable cells become NaN, i.e. missing values.
        let data_vec: Vec<f64> = file
            .lines()
            .map(|x| x.parse::<f64>().unwrap_or(f64::NAN))
            .collect();
        let file = fs::read_to_string("resources/performance.csv")
            .expect("Something went wrong reading the file");
        let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();

        let data = Matrix::new(&data_vec, 891, 5);
        let mut booster = GradientBooster::default()
            .set_iterations(10)
            .set_nbins(300)
            .set_max_depth(3)
            .set_base_score(0.5)
            .set_initialize_base_score(false);
        let sample_weight = vec![1.; y.len()];
        booster.fit(&data, &y, &sample_weight, None).unwrap();
        let preds = booster.predict(&data, true);

        booster.save_booster("resources/model64.json").unwrap();
        let booster2 = GradientBooster::load_booster("resources/model64.json").unwrap();
        // The reloaded model must predict identically.
        assert_eq!(booster2.predict(&data, true)[0..10], preds[0..10]);

        // `missing` is serialized (null <-> NaN mapping aside), so a custom
        // value must round-trip.
        booster.missing = 0.;
        booster.save_booster("resources/modelmissing.json").unwrap();
        let booster3 = GradientBooster::load_booster("resources/modelmissing.json").unwrap();
        assert_eq!(booster3.missing, 0.);
        assert_eq!(booster3.missing, booster.missing);
    }
1403}