/// Callback used to compute a custom loss for monitoring: receives the current
/// predictive distribution, the targets, and optional per-sample weights, and
/// returns a scalar loss. Must be `Send + Sync` so training can run in parallel.
pub type LossMonitor<D> = Box<dyn Fn(&D, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync>;
3
4use crate::dist::categorical::{Bernoulli, Categorical};
5use crate::dist::normal::Normal;
6use crate::dist::{ClassificationDistn, Distribution};
7use crate::learners::{default_tree_learner, BaseLearner, DecisionTreeLearner, TrainedBaseLearner};
8use crate::scores::{LogScore, Scorable, Score};
9use ndarray::{Array1, Array2};
10use rand::prelude::*;
11use rand::rngs::StdRng;
12use rand::SeedableRng;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
/// Per-iteration learning-rate schedule; consumed by `compute_learning_rate`,
/// where `progress = iteration / n_estimators` in [0, 1).
pub enum LearningRateSchedule {
    /// Fixed learning rate for every iteration (the default).
    #[default]
    Constant,
    /// Linear decay: `lr * max(1 - decay_rate * progress, min_lr_fraction)`.
    Linear {
        decay_rate: f64,
        min_lr_fraction: f64,
    },
    /// Exponential decay: `lr * exp(-decay_rate * progress)`.
    Exponential { decay_rate: f64 },
    /// Cosine annealing from `lr` down to 0 over the full run.
    Cosine,
    /// Cosine annealing that restarts every `restart_period` iterations.
    CosineWarmRestarts { restart_period: u32 },
}
39
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
/// Strategy for choosing the per-iteration scaling factor applied to the
/// fitted gradient predictions (see `line_search`).
pub enum LineSearchMethod {
    /// Doubling/halving search (the default): grow the scale while the score
    /// improves, then shrink until an improving step is found.
    #[default]
    Binary,
    /// Golden-section search over the bracketed scale interval.
    GoldenSection {
        /// Maximum number of interval-narrowing iterations.
        max_iters: usize,
    },
}
55
/// Golden ratio φ ≈ 1.618..., used by the golden-section line search.
const GOLDEN_RATIO: f64 = 1.618033988749895;
58
#[derive(Debug, Clone, Default)]
/// Per-iteration loss history recorded during `fit`.
pub struct EvalsResult {
    // Training loss at each boosting iteration.
    pub train: Vec<f64>,
    // Validation loss at each iteration; empty when no validation data is used.
    pub val: Vec<f64>,
}
67
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Serializable bundle of NGBoost hyperparameters.
pub struct NGBoostParams {
    // Number of boosting iterations.
    pub n_estimators: u32,
    // Step size applied to each boosting update.
    pub learning_rate: f64,
    // Use the natural gradient (metric-preconditioned) instead of the ordinary gradient.
    pub natural_gradient: bool,
    // Fraction of rows sampled per iteration, in (0, 1].
    pub minibatch_frac: f64,
    // Fraction of feature columns sampled per iteration, in (0, 1].
    pub col_sample: f64,
    // Print progress during training.
    pub verbose: bool,
    // Print interval: >= 1.0 is an iteration count, (0, 1) a fraction of n_estimators.
    pub verbose_eval: f64,
    // Convergence tolerance used by the line search.
    pub tol: f64,
    // Stop after this many validation iterations without improvement.
    pub early_stopping_rounds: Option<u32>,
    // Fraction of training data held out when early stopping needs a split, in [0, 1).
    pub validation_fraction: f64,
    // Seed for the RNG; None seeds from entropy.
    pub random_state: Option<u64>,
    // Learning-rate decay schedule.
    pub lr_schedule: LearningRateSchedule,
    // Tikhonov (ridge) regularization added to the metric for the natural gradient.
    pub tikhonov_reg: f64,
    // Line-search strategy for the per-iteration scaling.
    pub line_search_method: LineSearchMethod,
}
101
impl Default for NGBoostParams {
    /// Conventional defaults: 500 trees at lr 0.01 with natural gradients,
    /// no subsampling, no early stopping, constant schedule, binary line search.
    fn default() -> Self {
        Self {
            n_estimators: 500,
            learning_rate: 0.01,
            natural_gradient: true,
            minibatch_frac: 1.0,
            col_sample: 1.0,
            verbose: false,
            verbose_eval: 1.0,
            tol: 1e-4,
            early_stopping_rounds: None,
            validation_fraction: 0.1,
            random_state: None,
            lr_schedule: LearningRateSchedule::Constant,
            tikhonov_reg: 0.0,
            line_search_method: LineSearchMethod::Binary,
        }
    }
}
122
/// Natural gradient boosting model: fits one base learner per distribution
/// parameter at each iteration, scaled by a line search and a learning rate.
///
/// `D` is the predictive distribution, `S` the scoring rule, `B` the base
/// learner cloned fresh for each parameter/iteration.
pub struct NGBoost<D, S, B>
where
    D: Distribution + Scorable<S> + Clone,
    S: Score,
    B: BaseLearner + Clone,
{
    // Number of boosting iterations.
    pub n_estimators: u32,
    // Base step size for each boosting update.
    pub learning_rate: f64,
    // Use the natural gradient instead of the ordinary gradient.
    pub natural_gradient: bool,
    // Fraction of rows sampled each iteration, in (0, 1].
    pub minibatch_frac: f64,
    // Fraction of feature columns sampled each iteration, in (0, 1].
    pub col_sample: f64,
    // Print progress during training.
    pub verbose: bool,
    // Print interval: >= 1.0 is an iteration count, (0, 1) a fraction of n_estimators.
    pub verbose_eval: f64,
    // Line-search convergence tolerance.
    pub tol: f64,
    // Stop after this many validation iterations without improvement.
    pub early_stopping_rounds: Option<u32>,
    // Held-out fraction when early stopping needs an automatic split, in [0, 1).
    pub validation_fraction: f64,
    // Legacy flag: simple built-in decay overriding lr_schedule when set.
    pub adaptive_learning_rate: bool, pub lr_schedule: LearningRateSchedule,
    // Tikhonov regularization for the natural-gradient metric.
    pub tikhonov_reg: f64,
    // Strategy for the per-iteration scaling search.
    pub line_search_method: LineSearchMethod,

    // Template base learner, cloned for each parameter at each iteration.
    base_learner: B,

    // Fitted learners: one Vec per iteration, one learner per distribution parameter.
    pub base_models: Vec<Vec<Box<dyn TrainedBaseLearner>>>,
    // Line-search scale chosen at each iteration.
    pub scalings: Vec<f64>,
    // Marginal distribution parameters fitted to y; starting point for boosting.
    pub init_params: Option<Array1<f64>>,
    // Column subset used at each iteration (for replay at prediction time).
    pub col_idxs: Vec<Vec<usize>>,
    // Optional custom training-loss callback.
    train_loss_monitor: Option<LossMonitor<D>>,
    // Optional custom validation-loss callback.
    val_loss_monitor: Option<LossMonitor<D>>,
    // Iteration with the best validation loss, if validation was used.
    best_val_loss_itr: Option<usize>,
    // Number of features seen during fit.
    n_features: Option<usize>,
    // Per-iteration train/val loss history.
    pub evals_result: EvalsResult,

    // RNG driving row/column subsampling and the auto validation split.
    rng: StdRng,
    // Seed used for `rng`, if any, so state can be reported/serialized.
    random_state: Option<u64>,

    // Zero-sized markers tying the struct to its distribution and score types.
    _dist: PhantomData<D>,
    _score: PhantomData<S>,
}
178
179impl<D, S, B> NGBoost<D, S, B>
180where
181 D: Distribution + Scorable<S> + Clone,
182 S: Score,
183 B: BaseLearner + Clone,
184{
185 pub fn new(n_estimators: u32, learning_rate: f64, base_learner: B) -> Self {
186 NGBoost {
187 n_estimators,
188 learning_rate,
189 natural_gradient: true,
190 minibatch_frac: 1.0,
191 col_sample: 1.0,
192 verbose: false,
193 verbose_eval: 100.0,
194 tol: 1e-4,
195 early_stopping_rounds: None,
196 validation_fraction: 0.1,
197 adaptive_learning_rate: false,
198 lr_schedule: LearningRateSchedule::Constant,
199 tikhonov_reg: 0.0,
200 line_search_method: LineSearchMethod::Binary,
201 base_learner,
202 base_models: Vec::new(),
203 scalings: Vec::new(),
204 init_params: None,
205 col_idxs: Vec::new(),
206 train_loss_monitor: None,
207 val_loss_monitor: None,
208 best_val_loss_itr: None,
209 n_features: None,
210 evals_result: EvalsResult::default(),
211 rng: StdRng::from_rng(&mut rand::rng()),
212 random_state: None,
213 _dist: PhantomData,
214 _score: PhantomData,
215 }
216 }
217
218 pub fn with_seed(n_estimators: u32, learning_rate: f64, base_learner: B, seed: u64) -> Self {
221 NGBoost {
222 n_estimators,
223 learning_rate,
224 natural_gradient: true,
225 minibatch_frac: 1.0,
226 col_sample: 1.0,
227 verbose: false,
228 verbose_eval: 100.0,
229 tol: 1e-4,
230 early_stopping_rounds: None,
231 validation_fraction: 0.1,
232 adaptive_learning_rate: false,
233 lr_schedule: LearningRateSchedule::Constant,
234 tikhonov_reg: 0.0,
235 line_search_method: LineSearchMethod::Binary,
236 base_learner,
237 base_models: Vec::new(),
238 scalings: Vec::new(),
239 init_params: None,
240 col_idxs: Vec::new(),
241 train_loss_monitor: None,
242 val_loss_monitor: None,
243 best_val_loss_itr: None,
244 n_features: None,
245 evals_result: EvalsResult::default(),
246 rng: StdRng::seed_from_u64(seed),
247 random_state: Some(seed),
248 _dist: PhantomData,
249 _score: PhantomData,
250 }
251 }
252
    /// Re-seed the model RNG; affects subsequent subsampling and auto splits.
    pub fn set_random_state(&mut self, seed: u64) {
        self.random_state = Some(seed);
        self.rng = StdRng::seed_from_u64(seed);
    }
259
    /// The seed used for the RNG, if one was set explicitly.
    pub fn random_state(&self) -> Option<u64> {
        self.random_state
    }
264
    /// Per-iteration train/validation loss history recorded during fitting.
    pub fn evals_result(&self) -> &EvalsResult {
        &self.evals_result
    }
269
270 pub fn with_options(
271 n_estimators: u32,
272 learning_rate: f64,
273 base_learner: B,
274 natural_gradient: bool,
275 minibatch_frac: f64,
276 col_sample: f64,
277 verbose: bool,
278 verbose_eval: f64,
279 tol: f64,
280 early_stopping_rounds: Option<u32>,
281 validation_fraction: f64,
282 adaptive_learning_rate: bool,
283 ) -> Self {
284 NGBoost {
285 n_estimators,
286 learning_rate,
287 natural_gradient,
288 minibatch_frac,
289 col_sample,
290 verbose,
291 verbose_eval,
292 tol,
293 early_stopping_rounds,
294 validation_fraction,
295 adaptive_learning_rate,
296 lr_schedule: LearningRateSchedule::Constant,
297 tikhonov_reg: 0.0,
298 line_search_method: LineSearchMethod::Binary,
299 base_learner,
300 base_models: Vec::new(),
301 scalings: Vec::new(),
302 init_params: None,
303 col_idxs: Vec::new(),
304 train_loss_monitor: None,
305 val_loss_monitor: None,
306 best_val_loss_itr: None,
307 n_features: None,
308 evals_result: EvalsResult::default(),
309 rng: StdRng::from_rng(&mut rand::rng()),
310 random_state: None,
311 _dist: PhantomData,
312 _score: PhantomData,
313 }
314 }
315
    /// Canonical constructor for the common options plus an optional RNG seed.
    /// `None` seeds from entropy; `Some(seed)` gives reproducible runs.
    /// Advanced options (schedule, Tikhonov, line search) take their defaults.
    #[allow(clippy::too_many_arguments)]
    pub fn with_options_seeded(
        n_estimators: u32,
        learning_rate: f64,
        base_learner: B,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
        random_state: Option<u64>,
    ) -> Self {
        let rng = match random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rand::rng()),
        };
        NGBoost {
            n_estimators,
            learning_rate,
            natural_gradient,
            minibatch_frac,
            col_sample,
            verbose,
            verbose_eval,
            tol,
            early_stopping_rounds,
            validation_fraction,
            adaptive_learning_rate,
            lr_schedule: LearningRateSchedule::Constant,
            tikhonov_reg: 0.0,
            line_search_method: LineSearchMethod::Binary,
            base_learner,
            base_models: Vec::new(),
            scalings: Vec::new(),
            init_params: None,
            col_idxs: Vec::new(),
            train_loss_monitor: None,
            val_loss_monitor: None,
            best_val_loss_itr: None,
            n_features: None,
            evals_result: EvalsResult::default(),
            rng,
            random_state,
            _dist: PhantomData,
            _score: PhantomData,
        }
    }
368
369 #[allow(clippy::too_many_arguments)]
371 pub fn with_advanced_options(
372 n_estimators: u32,
373 learning_rate: f64,
374 base_learner: B,
375 natural_gradient: bool,
376 minibatch_frac: f64,
377 col_sample: f64,
378 verbose: bool,
379 verbose_eval: f64,
380 tol: f64,
381 early_stopping_rounds: Option<u32>,
382 validation_fraction: f64,
383 lr_schedule: LearningRateSchedule,
384 tikhonov_reg: f64,
385 line_search_method: LineSearchMethod,
386 ) -> Self {
387 NGBoost {
388 n_estimators,
389 learning_rate,
390 natural_gradient,
391 minibatch_frac,
392 col_sample,
393 verbose,
394 verbose_eval,
395 tol,
396 early_stopping_rounds,
397 validation_fraction,
398 adaptive_learning_rate: false,
399 lr_schedule,
400 tikhonov_reg,
401 line_search_method,
402 base_learner,
403 base_models: Vec::new(),
404 scalings: Vec::new(),
405 init_params: None,
406 col_idxs: Vec::new(),
407 train_loss_monitor: None,
408 val_loss_monitor: None,
409 best_val_loss_itr: None,
410 n_features: None,
411 evals_result: EvalsResult::default(),
412 rng: StdRng::from_rng(&mut rand::rng()),
413 random_state: None,
414 _dist: PhantomData,
415 _score: PhantomData,
416 }
417 }
418
    /// Canonical constructor exposing every option, including the advanced
    /// knobs and an optional RNG seed (`None` seeds from entropy).
    /// The legacy `adaptive_learning_rate` flag is left off; use `lr_schedule`.
    #[allow(clippy::too_many_arguments)]
    pub fn with_full_options(
        n_estimators: u32,
        learning_rate: f64,
        base_learner: B,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        lr_schedule: LearningRateSchedule,
        tikhonov_reg: f64,
        line_search_method: LineSearchMethod,
        random_state: Option<u64>,
    ) -> Self {
        let rng = match random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rand::rng()),
        };
        NGBoost {
            n_estimators,
            learning_rate,
            natural_gradient,
            minibatch_frac,
            col_sample,
            verbose,
            verbose_eval,
            tol,
            early_stopping_rounds,
            validation_fraction,
            adaptive_learning_rate: false,
            lr_schedule,
            tikhonov_reg,
            line_search_method,
            base_learner,
            base_models: Vec::new(),
            scalings: Vec::new(),
            init_params: None,
            col_idxs: Vec::new(),
            train_loss_monitor: None,
            val_loss_monitor: None,
            best_val_loss_itr: None,
            n_features: None,
            evals_result: EvalsResult::default(),
            rng,
            random_state,
            _dist: PhantomData,
            _score: PhantomData,
        }
    }
473
    /// Install a custom callback used to compute the training loss recorded
    /// each iteration (replaces the default total score).
    pub fn set_train_loss_monitor(&mut self, monitor: LossMonitor<D>) {
        self.train_loss_monitor = Some(monitor);
    }
478
    /// Install a custom callback used to compute the validation loss that
    /// drives early stopping (replaces the default total score).
    pub fn set_val_loss_monitor(&mut self, monitor: LossMonitor<D>) {
        self.val_loss_monitor = Some(monitor);
    }
483
    /// Fit the model from scratch on `x`/`y` with no explicit validation data.
    /// Existing ensemble state is discarded (see `partial_fit` to continue).
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.fit_with_validation(x, y, None, None, None, None)
    }
487
    /// Continue boosting on top of the existing ensemble (state is NOT reset),
    /// adding up to `n_estimators` further iterations on `x`/`y`.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.partial_fit_with_validation(x, y, None, None, None, None)
    }
497
    /// Continue boosting (state preserved) with optional validation data and
    /// per-sample weights for train and validation sets.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        // reset_state = false: keep previously fitted learners/scalings.
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, false)
    }
511
    /// Fit from scratch with optional validation data and per-sample weights;
    /// validation loss drives early stopping when configured.
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        // reset_state = true: discard any previously fitted ensemble.
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, true)
    }
523
524 fn validate_hyperparameters(&self) -> Result<(), &'static str> {
527 if self.n_estimators == 0 {
528 return Err("n_estimators must be greater than 0");
529 }
530 if self.learning_rate <= 0.0 {
531 return Err("learning_rate must be positive");
532 }
533 if self.learning_rate > 10.0 {
534 return Err("learning_rate > 10.0 is likely a mistake");
535 }
536 if self.minibatch_frac <= 0.0 || self.minibatch_frac > 1.0 {
537 return Err("minibatch_frac must be in (0, 1]");
538 }
539 if self.col_sample <= 0.0 || self.col_sample > 1.0 {
540 return Err("col_sample must be in (0, 1]");
541 }
542 if self.tol < 0.0 {
543 return Err("tol must be non-negative");
544 }
545 if self.validation_fraction < 0.0 || self.validation_fraction >= 1.0 {
546 return Err("validation_fraction must be in [0, 1)");
547 }
548 if self.tikhonov_reg < 0.0 {
549 return Err("tikhonov_reg must be non-negative");
550 }
551
552 match self.lr_schedule {
554 LearningRateSchedule::Linear {
555 decay_rate,
556 min_lr_fraction,
557 } => {
558 if decay_rate < 0.0 || decay_rate > 1.0 {
559 return Err("Linear schedule decay_rate must be in [0, 1]");
560 }
561 if min_lr_fraction < 0.0 || min_lr_fraction > 1.0 {
562 return Err("Linear schedule min_lr_fraction must be in [0, 1]");
563 }
564 }
565 LearningRateSchedule::Exponential { decay_rate } => {
566 if decay_rate < 0.0 {
567 return Err("Exponential schedule decay_rate must be non-negative");
568 }
569 }
570 LearningRateSchedule::CosineWarmRestarts { restart_period } => {
571 if restart_period == 0 {
572 return Err("CosineWarmRestarts restart_period must be > 0");
573 }
574 }
575 _ => {}
576 }
577
578 if let LineSearchMethod::GoldenSection { max_iters } = self.line_search_method {
580 if max_iters == 0 {
581 return Err("GoldenSection max_iters must be > 0");
582 }
583 }
584
585 Ok(())
586 }
587
    /// Core training loop shared by `fit` and `partial_fit`.
    ///
    /// Validates inputs, optionally carves out an automatic validation split
    /// (when early stopping is requested and no validation data was given),
    /// then runs up to `n_estimators` boosting iterations. Each iteration:
    /// compute (natural) gradients, subsample rows/columns, fit one base
    /// learner per distribution parameter, line-search a scale, and step the
    /// parameter matrix by `effective_lr * scale * predictions`.
    ///
    /// `reset_state = true` discards any existing ensemble first.
    fn fit_internal(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
        reset_state: bool,
    ) -> Result<(), &'static str> {
        self.validate_hyperparameters()?;

        // Basic shape/content checks on the training data.
        if x.nrows() != y.len() {
            return Err("Number of samples in X and y must match");
        }
        if x.nrows() == 0 {
            return Err("Cannot fit to empty dataset");
        }
        if x.ncols() == 0 {
            return Err("Cannot fit to dataset with no features");
        }

        if x.iter().any(|&v| !v.is_finite()) {
            return Err("Input X contains NaN or infinite values");
        }
        if y.iter().any(|&v| !v.is_finite()) {
            return Err("Input y contains NaN or infinite values");
        }

        if reset_state {
            self.base_models.clear();
            self.scalings.clear();
            self.col_idxs.clear();
            self.best_val_loss_itr = None;
            self.evals_result = EvalsResult::default();
        }
        self.n_features = Some(x.ncols());

        // Automatic hold-out split: only when early stopping is requested and
        // the caller supplied no validation data of their own.
        let (x_train, y_train, x_val_auto, y_val_auto) = if self.early_stopping_rounds.is_some()
            && x_val.is_none()
            && y_val.is_none()
            && self.validation_fraction > 0.0
            && self.validation_fraction < 1.0
        {
            let n_samples = x.nrows();
            let n_val = ((n_samples as f64) * self.validation_fraction) as usize;
            let n_train = n_samples - n_val;

            // In-place Fisher-Yates shuffle driven by the model RNG.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            for i in (1..indices.len()).rev() {
                let j = self.rng.random_range(0..=i);
                indices.swap(i, j);
            }

            let train_indices: Vec<usize> = indices[0..n_train].to_vec();
            let val_indices: Vec<usize> = indices[n_train..].to_vec();

            let x_train = x.select(ndarray::Axis(0), &train_indices);
            let y_train = y.select(ndarray::Axis(0), &train_indices);
            let x_val_auto = Some(x.select(ndarray::Axis(0), &val_indices));
            let y_val_auto = Some(y.select(ndarray::Axis(0), &val_indices));

            (x_train, y_train, x_val_auto, y_val_auto)
        } else {
            (x.to_owned(), y.to_owned(), x_val.cloned(), y_val.cloned())
        };

        // Prefer the automatic split (if made) over the caller-provided refs.
        let x_train = x_train;
        let y_train = y_train;
        let x_val = x_val_auto.as_ref().or(x_val);
        let y_val = y_val_auto.as_ref().or(y_val);

        if let (Some(xv), Some(yv)) = (x_val, y_val) {
            if xv.nrows() != yv.len() {
                return Err("Number of samples in validation X and y must match");
            }
            if xv.ncols() != x_train.ncols() {
                return Err("Number of features in training and validation data must match");
            }
        }

        // Initialize every row of the parameter matrix to the marginal fit.
        self.init_params = Some(D::fit(&y_train));
        let n_params = self.init_params.as_ref().unwrap().len();
        let mut params = Array2::from_elem((x_train.nrows(), n_params), 0.0);

        let init_params = self.init_params.as_ref().unwrap();
        params
            .outer_iter_mut()
            .for_each(|mut row| row.assign(init_params));

        // Validation rows get the same starting parameters.
        let mut val_params = if let (Some(xv), Some(_yv)) = (x_val, y_val) {
            let mut v_params = Array2::from_elem((xv.nrows(), n_params), 0.0);
            v_params
                .outer_iter_mut()
                .for_each(|mut row| row.assign(init_params));
            Some(v_params)
        } else {
            None
        };

        let mut best_val_loss = f64::INFINITY;
        let mut best_iter = 0;
        let mut no_improvement_count = 0;

        for itr in 0..self.n_estimators {
            let dist = D::from_params(&params);

            // Natural gradient with optional Tikhonov-regularized metric.
            let grads = if self.natural_gradient && self.tikhonov_reg > 0.0 {
                let standard_grad = Scorable::d_score(&dist, &y_train);
                let metric = Scorable::metric(&dist);
                crate::scores::natural_gradient_regularized(
                    &standard_grad,
                    &metric,
                    self.tikhonov_reg,
                )
            } else {
                Scorable::grad(&dist, &y_train, self.natural_gradient)
            };

            // Row/column subsample for this iteration; remember the column
            // subset so prediction can replay it.
            let (row_idxs, col_idxs, x_sampled, y_sampled, params_sampled, weight_sampled) =
                self.sample(&x_train, &y_train, &params, sample_weight);
            self.col_idxs.push(col_idxs.clone());

            let grads_sampled = grads.select(ndarray::Axis(0), &row_idxs);

            // Fit one base learner per distribution parameter, in parallel
            // when the "parallel" feature is enabled.
            #[cfg(feature = "parallel")]
            let fit_results: Vec<
                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
            > = {
                let learners: Vec<B> = (0..n_params).map(|_| self.base_learner.clone()).collect();
                learners
                    .into_par_iter()
                    .enumerate()
                    .map(|(j, learner)| {
                        let grad_j = grads_sampled.column(j).to_owned();
                        let fitted = learner.fit_with_weights(
                            &x_sampled,
                            &grad_j,
                            weight_sampled.as_ref(),
                        )?;
                        let preds = fitted.predict(&x_sampled);
                        Ok((fitted, preds))
                    })
                    .collect()
            };

            #[cfg(not(feature = "parallel"))]
            let fit_results: Vec<
                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
            > = (0..n_params)
                .map(|j| {
                    let grad_j = grads_sampled.column(j).to_owned();
                    let learner = self.base_learner.clone();
                    let fitted =
                        learner.fit_with_weights(&x_sampled, &grad_j, weight_sampled.as_ref())?;
                    let preds = fitted.predict(&x_sampled);
                    Ok((fitted, preds))
                })
                .collect();

            // Unpack, propagating the first learner-fitting error if any.
            let mut fitted_learners: Vec<Box<dyn TrainedBaseLearner>> =
                Vec::with_capacity(n_params);
            let mut predictions_cols: Vec<Array1<f64>> = Vec::with_capacity(n_params);
            for result in fit_results {
                let (fitted, preds) = result?;
                fitted_learners.push(fitted);
                predictions_cols.push(preds);
            }

            let predictions = to_2d_array(predictions_cols);

            // Line-search the step scale on the sampled subset only.
            let scale = self.line_search(
                &predictions,
                &params_sampled,
                &y_sampled,
                weight_sampled.as_ref(),
            );
            self.scalings.push(scale);
            self.base_models.push(fitted_learners);

            // Scheduled learning rate for this iteration.
            let progress = itr as f64 / self.n_estimators as f64;
            let effective_learning_rate = self.compute_learning_rate(itr, progress);

            // Re-predict on the FULL training set (not just the minibatch) to
            // advance every row's parameters.
            let fitted_learners = self.base_models.last().unwrap();
            let full_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == x_train.ncols() {
                fitted_learners
                    .iter()
                    .map(|learner| learner.predict(&x_train))
                    .collect()
            } else {
                let x_subset = x_train.select(ndarray::Axis(1), &col_idxs);
                fitted_learners
                    .iter()
                    .map(|learner| learner.predict(&x_subset))
                    .collect()
            };
            let full_predictions = to_2d_array(full_predictions_cols);

            params -= &(effective_learning_rate * scale * &full_predictions);

            if let (Some(xv), Some(yv), Some(vp)) = (x_val, y_val, val_params.as_mut()) {
                // Mirror the same update on the validation parameters.
                let fitted_learners = self.base_models.last().unwrap();
                let val_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == xv.ncols() {
                    fitted_learners
                        .iter()
                        .map(|learner| learner.predict(xv))
                        .collect()
                } else {
                    let xv_subset = xv.select(ndarray::Axis(1), &col_idxs);
                    fitted_learners
                        .iter()
                        .map(|learner| learner.predict(&xv_subset))
                        .collect()
                };
                let val_predictions = to_2d_array(val_predictions_cols);
                *vp -= &(effective_learning_rate * scale * &val_predictions);

                let val_dist = D::from_params(vp);
                let val_loss = if let Some(monitor) = &self.val_loss_monitor {
                    monitor(&val_dist, yv, val_sample_weight)
                } else {
                    Scorable::total_score(&val_dist, yv, val_sample_weight)
                };

                self.evals_result.val.push(val_loss);

                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    best_iter = itr;
                    no_improvement_count = 0;
                    self.best_val_loss_itr = Some(itr as usize);
                } else {
                    no_improvement_count += 1;
                }

                if let Some(rounds) = self.early_stopping_rounds {
                    if no_improvement_count >= rounds {
                        if self.verbose {
                            println!("== Early stopping achieved.");
                            println!(
                                "== Best iteration / VAL{} (val_loss={:.4})",
                                best_iter, best_val_loss
                            );
                        }
                        // NOTE: breaking here skips the train-loss push below,
                        // so `evals_result.train` can be one entry shorter
                        // than `evals_result.val` on early stop.
                        break;
                    }
                }

                let dist = D::from_params(&params);
                let train_loss = if let Some(monitor) = &self.train_loss_monitor {
                    monitor(&dist, &y_train, sample_weight)
                } else {
                    Scorable::total_score(&dist, &y_train, sample_weight)
                };
                self.evals_result.train.push(train_loss);

                if self.should_print_verbose(itr) {
                    println!(
                        "[iter {}] train_loss={:.4} val_loss={:.4}",
                        itr, train_loss, val_loss
                    );
                }
            } else {
                // No validation data: record training loss only.
                let dist = D::from_params(&params);
                let train_loss = if let Some(monitor) = &self.train_loss_monitor {
                    monitor(&dist, &y_train, sample_weight)
                } else {
                    Scorable::total_score(&dist, &y_train, sample_weight)
                };
                self.evals_result.train.push(train_loss);

                if self.should_print_verbose(itr) {
                    // Mean per-row gradient magnitude, for progress diagnostics.
                    let grad_norm: f64 =
                        grads.iter().map(|x| x * x).sum::<f64>().sqrt() / grads.len() as f64;

                    println!(
                        "[iter {}] loss={:.4} grad_norm={:.4} scale={:.4}",
                        itr, train_loss, grad_norm, scale
                    );
                }
            }
        }

        Ok(())
    }
909
910 fn sample(
911 &mut self,
912 x: &Array2<f64>,
913 y: &Array1<f64>,
914 params: &Array2<f64>,
915 sample_weight: Option<&Array1<f64>>,
916 ) -> (
917 Vec<usize>,
918 Vec<usize>,
919 Array2<f64>,
920 Array1<f64>,
921 Array2<f64>,
922 Option<Array1<f64>>,
923 ) {
924 let n_samples = x.nrows();
925 let n_features = x.ncols();
926
927 let sample_size = if self.minibatch_frac >= 1.0 {
929 n_samples
930 } else {
931 ((n_samples as f64) * self.minibatch_frac) as usize
932 };
933
934 let row_idxs: Vec<usize> = if sample_size == n_samples {
938 (0..n_samples).collect()
939 } else {
940 let mut indices: Vec<usize> = (0..n_samples).collect();
941 for i in (1..indices.len()).rev() {
943 let j = self.rng.random_range(0..=i);
944 indices.swap(i, j);
945 }
946 indices.into_iter().take(sample_size).collect()
947 };
948
949 let col_size = if self.col_sample >= 1.0 {
951 n_features
952 } else if self.col_sample > 0.0 {
953 ((n_features as f64) * self.col_sample) as usize
954 } else {
955 0
956 };
957
958 let col_idxs: Vec<usize> = if col_size == n_features || col_size == 0 {
959 (0..n_features).collect()
960 } else {
961 let mut indices: Vec<usize> = (0..n_features).collect();
962 indices.shuffle(&mut self.rng);
963 indices.into_iter().take(col_size).collect()
964 };
965
966 let x_sampled = if col_size == n_features {
970 x.select(ndarray::Axis(0), &row_idxs)
972 } else {
973 let mut result = Array2::zeros((row_idxs.len(), col_idxs.len()));
975 for (new_row, &old_row) in row_idxs.iter().enumerate() {
976 for (new_col, &old_col) in col_idxs.iter().enumerate() {
977 result[[new_row, new_col]] = x[[old_row, old_col]];
978 }
979 }
980 result
981 };
982 let y_sampled = y.select(ndarray::Axis(0), &row_idxs);
983 let params_sampled = params.select(ndarray::Axis(0), &row_idxs);
984
985 let sample_weights_sampled =
987 sample_weight.map(|weights| weights.select(ndarray::Axis(0), &row_idxs));
988
989 (
990 row_idxs,
991 col_idxs,
992 x_sampled,
993 y_sampled,
994 params_sampled,
995 sample_weights_sampled,
996 )
997 }
998
    /// Accumulated distribution parameters for `x` using the full ensemble.
    fn get_params(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params_at(x, None)
    }
1002
1003 fn get_params_at(&self, x: &Array2<f64>, max_iter: Option<usize>) -> Array2<f64> {
1004 if x.nrows() == 0 {
1005 return Array2::zeros((0, 0));
1006 }
1007
1008 let init_params = self
1009 .init_params
1010 .as_ref()
1011 .expect("Model has not been fitted. Call fit() before predict().");
1012 let n_params = init_params.len();
1013 let mut params = Array2::from_elem((x.nrows(), n_params), 0.0);
1014 params
1015 .outer_iter_mut()
1016 .for_each(|mut row| row.assign(init_params));
1017
1018 let n_iters = max_iter
1019 .unwrap_or(self.base_models.len())
1020 .min(self.base_models.len());
1021
1022 for (i, (learners, col_idx)) in self
1023 .base_models
1024 .iter()
1025 .zip(self.col_idxs.iter())
1026 .enumerate()
1027 .take(n_iters)
1028 {
1029 let scale = self.scalings[i];
1030
1031 let predictions_cols: Vec<Array1<f64>> = if col_idx.len() == x.ncols() {
1034 learners.iter().map(|learner| learner.predict(x)).collect()
1035 } else {
1036 let x_subset = x.select(ndarray::Axis(1), col_idx);
1037 learners
1038 .iter()
1039 .map(|learner| learner.predict(&x_subset))
1040 .collect()
1041 };
1042
1043 let predictions = to_2d_array(predictions_cols);
1044
1045 params -= &(self.learning_rate * scale * &predictions);
1046 }
1047 params
1048 }
1049
    /// Predicted distribution parameters for each row of `x` (full ensemble).
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params(x)
    }
1054
    /// Predicted distribution parameters using only the first `max_iter` stages.
    pub fn pred_param_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        self.get_params_at(x, Some(max_iter))
    }
1059
    /// Predictive distribution for each row of `x` (full ensemble).
    pub fn pred_dist(&self, x: &Array2<f64>) -> D {
        let params = self.get_params(x);
        D::from_params(&params)
    }
1064
    /// Predictive distribution using only the first `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> D {
        let params = self.get_params_at(x, Some(max_iter));
        D::from_params(&params)
    }
1070
    /// Point predictions (the distribution's `predict`) for each row of `x`.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.pred_dist(x).predict()
    }
1074
    /// Point predictions using only the first `max_iter` boosting stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.pred_dist_at(x, max_iter).predict()
    }
1079
    /// Iterator over point predictions after 1, 2, ..., N boosting stages.
    /// Each item re-runs the replay from scratch (O(stages^2) overall).
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        (1..=self.base_models.len()).map(move |i| self.predict_at(x, i))
    }
1087
    /// Iterator over predictive distributions after 1, 2, ..., N stages.
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = D> + 'a {
        (1..=self.base_models.len()).map(move |i| self.pred_dist_at(x, i))
    }
1092
    /// Unweighted total score of the fitted model on `(x, y)`; lower is better.
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        let dist = self.pred_dist(x);
        Scorable::total_score(&dist, y, None)
    }
1098
    /// Number of features seen during fitting; `None` before `fit`.
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }
1103
1104 fn should_print_verbose(&self, iteration: u32) -> bool {
1107 if !self.verbose || self.verbose_eval <= 0.0 {
1108 return false;
1109 }
1110
1111 let verbose_interval = if self.verbose_eval >= 1.0 {
1115 self.verbose_eval as u32
1116 } else {
1117 (self.n_estimators as f64 * self.verbose_eval).max(1.0) as u32
1119 };
1120
1121 verbose_interval > 0 && iteration % verbose_interval == 0
1122 }
1123
    /// Effective learning rate for `iteration`, where `progress` is the
    /// fraction of the run completed (`iteration / n_estimators`).
    ///
    /// The legacy `adaptive_learning_rate` flag takes precedence over
    /// `lr_schedule` and linearly decays to a floor of 10% of the base rate.
    fn compute_learning_rate(&self, iteration: u32, progress: f64) -> f64 {
        if self.adaptive_learning_rate {
            return self.learning_rate * (1.0 - 0.7 * progress).max(0.1);
        }

        match self.lr_schedule {
            LearningRateSchedule::Constant => self.learning_rate,
            LearningRateSchedule::Linear {
                decay_rate,
                min_lr_fraction,
            } => self.learning_rate * (1.0 - decay_rate * progress).max(min_lr_fraction),
            LearningRateSchedule::Exponential { decay_rate } => {
                self.learning_rate * (-decay_rate * progress).exp()
            }
            LearningRateSchedule::Cosine => {
                self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * progress).cos())
            }
            LearningRateSchedule::CosineWarmRestarts { restart_period } => {
                // Progress within the current restart window, in [0, 1).
                let period_progress = (iteration % restart_period) as f64 / restart_period as f64;
                self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * period_progress).cos())
            }
        }
    }
1149
    /// Per-parameter feature importances, shape `(n_params, n_features)`.
    ///
    /// Each fitted learner contributes `|scale|` of its iteration to the
    /// feature it split on (via `split_feature`); each row is then normalized
    /// to sum to 1. Returns `None` before fitting.
    /// NOTE(review): only the learner's single reported split feature is
    /// credited — learners with deeper structure are summarized by that one
    /// feature; confirm against `TrainedBaseLearner::split_feature`.
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        let n_features = self.n_features?;
        if self.base_models.is_empty() || n_features == 0 {
            return None;
        }

        let n_params = self.init_params.as_ref()?.len();
        let mut importances = Array2::zeros((n_params, n_features));

        for (iter_idx, learners) in self.base_models.iter().enumerate() {
            let scale = self.scalings[iter_idx].abs();

            for (param_idx, learner) in learners.iter().enumerate() {
                if let Some(feature_idx) = learner.split_feature() {
                    if feature_idx < n_features {
                        importances[[param_idx, feature_idx]] += scale;
                    }
                }
            }
        }

        // Normalize each parameter's row to a probability-like distribution.
        for mut row in importances.rows_mut() {
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row.mapv_inplace(|v| v / sum);
            }
        }

        Some(importances)
    }
1186
    /// Rescale the model's initial scale parameter so the predicted variance
    /// matches the empirical squared error on `(x_val, y_val)`.
    ///
    /// # Errors
    /// Returns an error if the model has not been trained.
    pub fn calibrate_uncertainty(
        &mut self,
        x_val: &Array2<f64>,
        y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        if self.base_models.is_empty() {
            return Err("Model must be trained before calibration");
        }

        let params = self.pred_param(x_val);
        let dist = D::from_params(&params);

        let predictions = dist.predict();
        let errors = y_val - &predictions;

        // Mean squared residual; falls back to 1.0 for an empty validation set.
        let empirical_var = errors.mapv(|e| e * e).mean().unwrap_or(1.0);

        // Assumes init_params[1] is a log-scale parameter (e.g. Normal's
        // log-sigma) — TODO confirm against the distribution parameterization.
        // NOTE(review): the negation implies params[1] stores a NEGATIVE
        // log-variance-like quantity; verify the sign convention.
        if let Some(init_params) = self.init_params.as_mut() {
            if init_params.len() >= 2 {
                let current_var = (-init_params[1]).exp(); let target_var = empirical_var;
                let calibration_factor = (target_var / current_var).sqrt();
                // Shift in log-space by the log of the std-dev ratio.
                init_params[1] += calibration_factor.ln();
            }
        }

        Ok(())
    }
1222
    /// Feature importances summed over all distribution parameters and
    /// renormalized to sum to 1. Returns `None` before fitting.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        let importances = self.feature_importances()?;
        let mut aggregated = importances.sum_axis(ndarray::Axis(0));

        let sum: f64 = aggregated.sum();
        if sum > 0.0 {
            aggregated.mapv_inplace(|v| v / sum);
        }

        Some(aggregated)
    }
1236
    /// Dispatch to the configured line-search method to pick the scaling
    /// applied to `resids` (the base learners' predicted gradient step) that
    /// best improves the score starting from `start`.
    fn line_search(
        &self,
        resids: &Array2<f64>,
        start: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> f64 {
        match self.line_search_method {
            LineSearchMethod::Binary => self.line_search_binary(resids, start, y, sample_weight),
            LineSearchMethod::GoldenSection { max_iters } => {
                self.line_search_golden_section(resids, start, y, sample_weight, max_iters)
            }
        }
    }
1251
    /// Doubling/halving line search starting from scale 1.0.
    ///
    /// Phase 1 doubles the scale (capped at 256) while the doubled step keeps
    /// strictly improving the score; phase 2 halves it until a finite,
    /// improving step is found, the mean per-row step norm drops below `tol`,
    /// or the scale underflows 1e-10. Returns the final scale (which may be a
    /// non-improving fallback if no improving step exists).
    fn line_search_binary(
        &self,
        resids: &Array2<f64>,
        start: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> f64 {
        let mut scale = 1.0;
        let initial_score = Scorable::total_score(&D::from_params(start), y, sample_weight);

        // Phase 1: expand while doubling still improves on the initial score.
        loop {
            if scale > 256.0 {
                break;
            }
            let scaled_resids = resids * (scale * 2.0);
            let next_params = start - &scaled_resids;
            let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
            if score >= initial_score || !score.is_finite() {
                break;
            }
            scale *= 2.0;
        }

        // Phase 2: shrink until the step improves, is negligible, or vanishes.
        loop {
            let scaled_resids = resids * scale;
            // Mean Euclidean norm of the per-row step.
            let norm: f64 = scaled_resids
                .rows()
                .into_iter()
                .map(|row| row.iter().map(|x| x * x).sum::<f64>().sqrt())
                .sum::<f64>()
                / scaled_resids.nrows() as f64;
            if norm < self.tol {
                break;
            }

            let next_params = start - &scaled_resids;
            let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
            if score < initial_score && score.is_finite() {
                break;
            }
            scale *= 0.5;

            // Guard against an infinite shrink loop.
            if scale < 1e-10 {
                break;
            }
        }

        scale
    }
1304
1305 fn line_search_golden_section(
1308 &self,
1309 resids: &Array2<f64>,
1310 start: &Array2<f64>,
1311 y: &Array1<f64>,
1312 sample_weight: Option<&Array1<f64>>,
1313 max_iters: usize,
1314 ) -> f64 {
1315 let compute_score = |scale: f64| -> f64 {
1317 let scaled_resids = resids * scale;
1318 let next_params = start - &scaled_resids;
1319 Scorable::total_score(&D::from_params(&next_params), y, sample_weight)
1320 };
1321
1322 let initial_score = compute_score(0.0);
1323
1324 let mut upper = 1.0;
1326 while upper < 256.0 {
1327 let score = compute_score(upper * 2.0);
1328 if score >= initial_score || !score.is_finite() {
1329 break;
1330 }
1331 upper *= 2.0;
1332 }
1333
1334 let mut a = 0.0;
1336 let mut b = upper;
1337 let inv_phi = 1.0 / GOLDEN_RATIO;
1338 let _inv_phi2 = 1.0 / (GOLDEN_RATIO * GOLDEN_RATIO); let mut c = b - (b - a) * inv_phi;
1342 let mut d = a + (b - a) * inv_phi;
1343 let mut fc = compute_score(c);
1344 let mut fd = compute_score(d);
1345
1346 for _ in 0..max_iters {
1347 if (b - a).abs() < self.tol {
1348 break;
1349 }
1350
1351 if fc < fd {
1352 b = d;
1354 d = c;
1355 fd = fc;
1356 c = b - (b - a) * inv_phi;
1357 fc = compute_score(c);
1358 } else {
1359 a = c;
1361 c = d;
1362 fc = fd;
1363 d = a + (b - a) * inv_phi;
1364 fd = compute_score(d);
1365 }
1366 }
1367
1368 let scale = (a + b) / 2.0;
1370
1371 let final_score = compute_score(scale);
1373 if final_score < initial_score && final_score.is_finite() {
1374 scale
1375 } else {
1376 1.0
1378 }
1379 }
1380
    /// Snapshot the fitted model into a plain-data, serde-serializable form.
    pub fn serialize(&self) -> Result<SerializedNGBoost, Box<dyn std::error::Error>> {
        // NOTE(review): `filter_map` silently drops any learner for which
        // `to_serializable()` returns `None`. If that can happen for a fitted
        // learner, a reloaded model would silently lose boosting stages —
        // confirm `None` is impossible here, or surface it as an Err.
        let serialized_base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>> = self
            .base_models
            .iter()
            .map(|learners| {
                learners
                    .iter()
                    .filter_map(|learner| learner.to_serializable())
                    .collect()
            })
            .collect();

        Ok(SerializedNGBoost {
            n_estimators: self.n_estimators,
            learning_rate: self.learning_rate,
            natural_gradient: self.natural_gradient,
            minibatch_frac: self.minibatch_frac,
            col_sample: self.col_sample,
            verbose: self.verbose,
            verbose_eval: self.verbose_eval,
            tol: self.tol,
            early_stopping_rounds: self.early_stopping_rounds,
            validation_fraction: self.validation_fraction,
            init_params: self.init_params.as_ref().map(|p| p.to_vec()),
            scalings: self.scalings.clone(),
            col_idxs: self.col_idxs.clone(),
            best_val_loss_itr: self.best_val_loss_itr,
            base_models: serialized_base_models,
            lr_schedule: self.lr_schedule,
            tikhonov_reg: self.tikhonov_reg,
            line_search_method: self.line_search_method,
            n_features: self.n_features,
            random_state: self.random_state,
        })
    }
1418
    /// Rebuild a model from a [`SerializedNGBoost`] snapshot plus a fresh
    /// base-learner template used for any future (partial) fitting.
    pub fn deserialize(
        serialized: SerializedNGBoost,
        base_learner: B,
    ) -> Result<Self, Box<dyn std::error::Error>>
    where
        D: Distribution + Scorable<S> + Clone,
        S: Score,
        B: BaseLearner + Clone,
    {
        let mut model = Self::with_options_seeded(
            serialized.n_estimators,
            serialized.learning_rate,
            base_learner,
            serialized.natural_gradient,
            serialized.minibatch_frac,
            serialized.col_sample,
            serialized.verbose,
            serialized.verbose_eval,
            serialized.tol,
            serialized.early_stopping_rounds,
            serialized.validation_fraction,
            // NOTE(review): adaptive_learning_rate is not part of the
            // snapshot, so it is reset to off here — confirm intended.
            false,
            serialized.random_state,
        );

        if let Some(init_params) = serialized.init_params {
            model.init_params = Some(Array1::from(init_params));
        }
        model.scalings = serialized.scalings;
        model.col_idxs = serialized.col_idxs;
        model.best_val_loss_itr = serialized.best_val_loss_itr;

        model.lr_schedule = serialized.lr_schedule;
        model.tikhonov_reg = serialized.tikhonov_reg;
        model.line_search_method = serialized.line_search_method;
        model.n_features = serialized.n_features;

        // Restore the fitted ensemble as trait objects.
        model.base_models = serialized
            .base_models
            .into_iter()
            .map(|learners| learners.into_iter().map(|l| l.to_trait_object()).collect())
            .collect();

        Ok(model)
    }
1468}
1469
/// Plain-data snapshot of a fitted NGBoost model, suitable for on-disk
/// (de)serialization via serde. Fields added after the original format are
/// marked `#[serde(default)]` so older snapshots still deserialize.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct SerializedNGBoost {
    pub n_estimators: u32,
    pub learning_rate: f64,
    pub natural_gradient: bool,
    pub minibatch_frac: f64,
    pub col_sample: f64,
    pub verbose: bool,
    pub verbose_eval: f64,
    pub tol: f64,
    pub early_stopping_rounds: Option<u32>,
    pub validation_fraction: f64,
    // Initial distribution parameters (flattened), if the model was fitted.
    pub init_params: Option<Vec<f64>>,
    // Per-stage line-search scalings.
    pub scalings: Vec<f64>,
    // Per-stage column (feature) subsample indices.
    pub col_idxs: Vec<Vec<usize>>,
    // Best validation-loss iteration found during early stopping, if any.
    pub best_val_loss_itr: Option<usize>,
    // One serialized learner per distribution parameter, per stage.
    pub base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>>,
    #[serde(default)]
    pub lr_schedule: LearningRateSchedule,
    #[serde(default)]
    pub tikhonov_reg: f64,
    #[serde(default)]
    pub line_search_method: LineSearchMethod,
    #[serde(default)]
    pub n_features: Option<usize>,
    #[serde(default)]
    pub random_state: Option<u64>,
}
1506
1507fn to_2d_array(cols: Vec<Array1<f64>>) -> Array2<f64> {
1508 if cols.is_empty() {
1509 return Array2::zeros((0, 0));
1510 }
1511 let nrows = cols[0].len();
1512 let ncols = cols.len();
1513 let mut arr = Array2::zeros((nrows, ncols));
1514 for (j, col) in cols.iter().enumerate() {
1515 arr.column_mut(j).assign(col);
1516 }
1517 arr
1518}
1519
/// Convenience wrapper: NGBoost regression with a `Normal` predictive
/// distribution, log score, and decision-tree base learners.
pub struct NGBRegressor {
    model: NGBoost<Normal, LogScore, DecisionTreeLearner>,
}
1524
/// Convenience wrapper: NGBoost binary classification with a `Bernoulli`
/// predictive distribution, log score, and decision-tree base learners.
pub struct NGBClassifier {
    model: NGBoost<Bernoulli, LogScore, DecisionTreeLearner>,
}
1528
impl NGBRegressor {
    /// Create a regressor with the default tree learner and default options.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Fit on `x`/`y` with no explicit validation set.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with an optional held-out validation set (no sample weights).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Continue training an already-fitted model on new data.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Continue training with an optional validation set.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Point predictions using the full ensemble.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Point predictions using only the first `max_iter` boosting stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Iterator over predictions after each successive boosting stage.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Predicted `Normal` distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Normal {
        self.model.pred_dist(x)
    }

    /// Predicted distribution using only the first `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Normal {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Iterator over predicted distributions after each stage.
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = Normal> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Raw predicted distribution parameters (one row per sample).
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Model score on `x`/`y` (as defined by the underlying scoring rule).
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }

    /// Install a custom training-loss monitor callback.
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Install a custom validation-loss monitor callback.
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Create a regressor with all options specified explicitly.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }

    /// Backward-compatible constructor from before `adaptive_learning_rate`
    /// existed; the new option defaults to off.
    pub fn with_options_compat(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
    ) -> Self {
        Self::with_options(
            n_estimators,
            learning_rate,
            natural_gradient,
            minibatch_frac,
            col_sample,
            verbose,
            verbose_eval,
            tol,
            early_stopping_rounds,
            validation_fraction,
            false, // adaptive_learning_rate
        )
    }

    /// Toggle the adaptive learning-rate behavior on the underlying model.
    pub fn set_adaptive_learning_rate(&mut self, enabled: bool) {
        self.model.adaptive_learning_rate = enabled;
    }

    /// Calibrate predictive uncertainty against a held-out validation set.
    pub fn calibrate_uncertainty(
        &mut self,
        x_val: &Array2<f64>,
        y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        self.model.calibrate_uncertainty(x_val, y_val)
    }

    /// Configured number of boosting stages.
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Configured base learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Whether natural gradients are used.
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Fraction of rows sampled per boosting stage.
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Fraction of columns sampled per boosting stage.
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Best validation-loss iteration found during fitting, if any.
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Early-stopping patience in rounds, if enabled.
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Fraction of training data held out for validation.
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Number of features seen during fitting, if fitted.
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Per-parameter feature importances, if available.
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Feature importances aggregated over distribution parameters.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }

    /// Serialize the fitted model and write it to `path` via bincode.
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        let serialized = self.model.serialize()?;
        let encoded = bincode::serialize(&serialized)?;
        std::fs::write(path, encoded)?;
        Ok(())
    }

    /// Load a model previously written by [`Self::save_model`].
    pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoded = std::fs::read(path)?;
        let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
        let model = NGBoost::<Normal, LogScore, DecisionTreeLearner>::deserialize(
            serialized,
            default_tree_learner(),
        )?;
        Ok(Self { model })
    }

    /// Train/validation loss history recorded during fitting.
    pub fn evals_result(&self) -> &EvalsResult {
        self.model.evals_result()
    }

    /// Fix the RNG seed for reproducible fitting.
    pub fn set_random_state(&mut self, seed: u64) {
        self.model.set_random_state(seed);
    }

    /// The configured RNG seed, if any.
    pub fn random_state(&self) -> Option<u64> {
        self.model.random_state()
    }

    /// Fit with optional per-sample weights (no validation set).
    pub fn fit_with_weights(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, None, None, sample_weight, None)
    }

    /// Fit with optional sample weights and an optional validation set.
    pub fn fit_with_weights_and_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, sample_weight, val_sample_weight)
    }

    /// Snapshot the current hyperparameters.
    pub fn get_params(&self) -> NGBoostParams {
        NGBoostParams {
            n_estimators: self.model.n_estimators,
            learning_rate: self.model.learning_rate,
            natural_gradient: self.model.natural_gradient,
            minibatch_frac: self.model.minibatch_frac,
            col_sample: self.model.col_sample,
            verbose: self.model.verbose,
            verbose_eval: self.model.verbose_eval,
            tol: self.model.tol,
            early_stopping_rounds: self.model.early_stopping_rounds,
            validation_fraction: self.model.validation_fraction,
            random_state: self.model.random_state(),
            lr_schedule: self.model.lr_schedule,
            tikhonov_reg: self.model.tikhonov_reg,
            line_search_method: self.model.line_search_method,
        }
    }

    /// Apply a hyperparameter snapshot. A `None` random_state leaves the
    /// current seed unchanged (it does not clear it).
    pub fn set_params(&mut self, params: NGBoostParams) {
        self.model.n_estimators = params.n_estimators;
        self.model.learning_rate = params.learning_rate;
        self.model.natural_gradient = params.natural_gradient;
        self.model.minibatch_frac = params.minibatch_frac;
        self.model.col_sample = params.col_sample;
        self.model.verbose = params.verbose;
        self.model.verbose_eval = params.verbose_eval;
        self.model.tol = params.tol;
        self.model.early_stopping_rounds = params.early_stopping_rounds;
        self.model.validation_fraction = params.validation_fraction;
        self.model.lr_schedule = params.lr_schedule;
        self.model.tikhonov_reg = params.tikhonov_reg;
        self.model.line_search_method = params.line_search_method;
        if let Some(seed) = params.random_state {
            self.model.set_random_state(seed);
        }
    }
}
1862
/// Convenience wrapper: NGBoost multiclass classification over `K` classes
/// with a `Categorical<K>` predictive distribution, log score, and
/// decision-tree base learners.
pub struct NGBMultiClassifier<const K: usize> {
    model: NGBoost<Categorical<K>, LogScore, DecisionTreeLearner>,
}
1880
impl<const K: usize> NGBMultiClassifier<K> {
    /// Create a multiclass classifier with the default tree learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Create a classifier with all options specified explicitly.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }

    /// Install a custom training-loss monitor callback.
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Categorical<K>, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Install a custom validation-loss monitor callback.
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Categorical<K>, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Fit on `x`/`y` (labels encoded as f64) with no validation set.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with an optional held-out validation set (no sample weights).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Continue training an already-fitted model on new data.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Continue training with an optional validation set.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Predicted labels (as f64) using the full ensemble.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Predicted labels using only the first `max_iter` boosting stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Iterator over predictions after each successive boosting stage.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Per-class probabilities, one row per sample and `K` columns.
    pub fn predict_proba(&self, x: &Array2<f64>) -> Array2<f64> {
        let dist = self.model.pred_dist(x);
        dist.class_probs()
    }

    /// Per-class probabilities using only the first `max_iter` stages.
    pub fn predict_proba_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        let dist = self.model.pred_dist_at(x, max_iter);
        dist.class_probs()
    }

    /// Iterator over class probabilities after each successive stage.
    pub fn staged_predict_proba<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array2<f64>> + 'a {
        (1..=self.model.base_models.len()).map(move |i| self.predict_proba_at(x, i))
    }

    /// Predicted `Categorical<K>` distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Categorical<K> {
        self.model.pred_dist(x)
    }

    /// Predicted distribution using only the first `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Categorical<K> {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Iterator over predicted distributions after each stage.
    pub fn staged_pred_dist<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Categorical<K>> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Raw predicted distribution parameters (one row per sample).
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Model score on `x`/`y` (as defined by the underlying scoring rule).
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }

    /// Configured number of boosting stages.
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Configured base learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Whether natural gradients are used.
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Fraction of rows sampled per boosting stage.
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Fraction of columns sampled per boosting stage.
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Best validation-loss iteration found during fitting, if any.
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Early-stopping patience in rounds, if enabled.
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Fraction of training data held out for validation.
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Number of features seen during fitting, if fitted.
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Per-parameter feature importances, if available.
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Feature importances aggregated over distribution parameters.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }

    /// Train/validation loss history recorded during fitting.
    pub fn evals_result(&self) -> &EvalsResult {
        self.model.evals_result()
    }

    /// Fix the RNG seed for reproducible fitting.
    pub fn set_random_state(&mut self, seed: u64) {
        self.model.set_random_state(seed);
    }

    /// The configured RNG seed, if any.
    pub fn random_state(&self) -> Option<u64> {
        self.model.random_state()
    }

    /// Fit with optional per-sample weights (no validation set).
    pub fn fit_with_weights(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, None, None, sample_weight, None)
    }

    /// Fit with optional sample weights and an optional validation set.
    pub fn fit_with_weights_and_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, sample_weight, val_sample_weight)
    }

    /// Snapshot the current hyperparameters.
    pub fn get_params(&self) -> NGBoostParams {
        NGBoostParams {
            n_estimators: self.model.n_estimators,
            learning_rate: self.model.learning_rate,
            natural_gradient: self.model.natural_gradient,
            minibatch_frac: self.model.minibatch_frac,
            col_sample: self.model.col_sample,
            verbose: self.model.verbose,
            verbose_eval: self.model.verbose_eval,
            tol: self.model.tol,
            early_stopping_rounds: self.model.early_stopping_rounds,
            validation_fraction: self.model.validation_fraction,
            random_state: self.model.random_state(),
            lr_schedule: self.model.lr_schedule,
            tikhonov_reg: self.model.tikhonov_reg,
            line_search_method: self.model.line_search_method,
        }
    }

    /// Apply a hyperparameter snapshot. A `None` random_state leaves the
    /// current seed unchanged (it does not clear it).
    pub fn set_params(&mut self, params: NGBoostParams) {
        self.model.n_estimators = params.n_estimators;
        self.model.learning_rate = params.learning_rate;
        self.model.natural_gradient = params.natural_gradient;
        self.model.minibatch_frac = params.minibatch_frac;
        self.model.col_sample = params.col_sample;
        self.model.verbose = params.verbose;
        self.model.verbose_eval = params.verbose_eval;
        self.model.tol = params.tol;
        self.model.early_stopping_rounds = params.early_stopping_rounds;
        self.model.validation_fraction = params.validation_fraction;
        self.model.lr_schedule = params.lr_schedule;
        self.model.tikhonov_reg = params.tikhonov_reg;
        self.model.line_search_method = params.line_search_method;
        if let Some(seed) = params.random_state {
            self.model.set_random_state(seed);
        }
    }

    /// The number of classes `K`.
    pub fn n_classes(&self) -> usize {
        K
    }
}
2174
/// 3-class NGBoost classifier.
pub type NGBMultiClassifier3 = NGBMultiClassifier<3>;

/// 4-class NGBoost classifier.
pub type NGBMultiClassifier4 = NGBMultiClassifier<4>;

/// 5-class NGBoost classifier.
pub type NGBMultiClassifier5 = NGBMultiClassifier<5>;

/// 10-class NGBoost classifier.
pub type NGBMultiClassifier10 = NGBMultiClassifier<10>;
2186
impl NGBClassifier {
    /// Create a binary classifier with the default tree learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Create a classifier with all options specified explicitly.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }

    /// Install a custom training-loss monitor callback.
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Install a custom validation-loss monitor callback.
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Fit on `x`/`y` (labels encoded as f64) with no validation set.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with an optional held-out validation set (no sample weights).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Continue training an already-fitted model on new data.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Continue training with an optional validation set.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Predicted labels (as f64) using the full ensemble.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Predicted labels using only the first `max_iter` boosting stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Iterator over predictions after each successive boosting stage.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Per-class probabilities, one row per sample.
    pub fn predict_proba(&self, x: &Array2<f64>) -> Array2<f64> {
        let dist = self.model.pred_dist(x);
        dist.class_probs()
    }

    /// Per-class probabilities using only the first `max_iter` stages.
    pub fn predict_proba_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        let dist = self.model.pred_dist_at(x, max_iter);
        dist.class_probs()
    }

    /// Iterator over class probabilities after each successive stage.
    pub fn staged_predict_proba<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array2<f64>> + 'a {
        (1..=self.model.base_models.len()).map(move |i| self.predict_proba_at(x, i))
    }

    /// Predicted `Bernoulli` distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Bernoulli {
        self.model.pred_dist(x)
    }

    /// Predicted distribution using only the first `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Bernoulli {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Iterator over predicted distributions after each stage.
    pub fn staged_pred_dist<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Bernoulli> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Raw predicted distribution parameters (one row per sample).
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Model score on `x`/`y` (as defined by the underlying scoring rule).
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }

    /// Configured number of boosting stages.
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Configured base learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Whether natural gradients are used.
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Fraction of rows sampled per boosting stage.
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Fraction of columns sampled per boosting stage.
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Best validation-loss iteration found during fitting, if any.
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Early-stopping patience in rounds, if enabled.
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Fraction of training data held out for validation.
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Number of features seen during fitting, if fitted.
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Per-parameter feature importances, if available.
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Feature importances aggregated over distribution parameters.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }

    /// Serialize the fitted model and write it to `path` via bincode.
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        let serialized = self.model.serialize()?;
        let encoded = bincode::serialize(&serialized)?;
        std::fs::write(path, encoded)?;
        Ok(())
    }

    /// Load a model previously written by [`Self::save_model`].
    pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoded = std::fs::read(path)?;
        let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
        let model = NGBoost::<Bernoulli, LogScore, DecisionTreeLearner>::deserialize(
            serialized,
            default_tree_learner(),
        )?;
        Ok(Self { model })
    }

    /// No-op kept for API parity with `NGBRegressor::calibrate_uncertainty`.
    /// NOTE(review): this intentionally does nothing for the Bernoulli
    /// classifier — confirm no variance-style calibration is expected here.
    pub fn calibrate_uncertainty(
        &mut self,
        _x_val: &Array2<f64>,
        _y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        Ok(())
    }

    /// Train/validation loss history recorded during fitting.
    pub fn evals_result(&self) -> &EvalsResult {
        self.model.evals_result()
    }

    /// Fix the RNG seed for reproducible fitting.
    pub fn set_random_state(&mut self, seed: u64) {
        self.model.set_random_state(seed);
    }

    /// The configured RNG seed, if any.
    pub fn random_state(&self) -> Option<u64> {
        self.model.random_state()
    }

    /// Fit with optional per-sample weights (no validation set).
    pub fn fit_with_weights(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, None, None, sample_weight, None)
    }

    /// Fit with optional sample weights and an optional validation set.
    pub fn fit_with_weights_and_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, sample_weight, val_sample_weight)
    }

    /// Snapshot the current hyperparameters.
    pub fn get_params(&self) -> NGBoostParams {
        NGBoostParams {
            n_estimators: self.model.n_estimators,
            learning_rate: self.model.learning_rate,
            natural_gradient: self.model.natural_gradient,
            minibatch_frac: self.model.minibatch_frac,
            col_sample: self.model.col_sample,
            verbose: self.model.verbose,
            verbose_eval: self.model.verbose_eval,
            tol: self.model.tol,
            early_stopping_rounds: self.model.early_stopping_rounds,
            validation_fraction: self.model.validation_fraction,
            random_state: self.model.random_state(),
            lr_schedule: self.model.lr_schedule,
            tikhonov_reg: self.model.tikhonov_reg,
            line_search_method: self.model.line_search_method,
        }
    }

    /// Apply a hyperparameter snapshot. A `None` random_state leaves the
    /// current seed unchanged (it does not clear it).
    pub fn set_params(&mut self, params: NGBoostParams) {
        self.model.n_estimators = params.n_estimators;
        self.model.learning_rate = params.learning_rate;
        self.model.natural_gradient = params.natural_gradient;
        self.model.minibatch_frac = params.minibatch_frac;
        self.model.col_sample = params.col_sample;
        self.model.verbose = params.verbose;
        self.model.verbose_eval = params.verbose_eval;
        self.model.tol = params.tol;
        self.model.early_stopping_rounds = params.early_stopping_rounds;
        self.model.validation_fraction = params.validation_fraction;
        self.model.lr_schedule = params.lr_schedule;
        self.model.tikhonov_reg = params.tikhonov_reg;
        self.model.line_search_method = params.line_search_method;
        if let Some(seed) = params.random_state {
            self.model.set_random_state(seed);
        }
    }
}