/// Callback for monitoring loss during training: receives the current
/// predicted distribution, the targets, and optional sample weights, and
/// returns a scalar loss value.
pub type LossMonitor<D> = Box<dyn Fn(&D, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync>;
3
4use crate::dist::categorical::Bernoulli;
5use crate::dist::normal::Normal;
6use crate::dist::{ClassificationDistn, Distribution};
7use crate::learners::{default_tree_learner, BaseLearner, DecisionTreeLearner, TrainedBaseLearner};
8use crate::scores::{LogScore, Scorable, Score};
9use ndarray::{Array1, Array2};
10use rand::prelude::*;
11use rand::rng;
12use std::marker::PhantomData;
13
14#[cfg(feature = "parallel")]
15use rayon::prelude::*;
16
/// Schedule controlling how the learning rate evolves over boosting
/// iterations (applied by `compute_learning_rate`).
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum LearningRateSchedule {
    /// Fixed learning rate for every iteration (the default).
    #[default]
    Constant,
    /// Linear decay: `lr * max(1 - decay_rate * progress, min_lr_fraction)`.
    Linear {
        decay_rate: f64,
        min_lr_fraction: f64,
    },
    /// Exponential decay: `lr * exp(-decay_rate * progress)`.
    Exponential { decay_rate: f64 },
    /// Half-cosine annealing from `lr` toward 0 over the full run.
    Cosine,
    /// Cosine annealing restarted every `restart_period` iterations.
    CosineWarmRestarts { restart_period: u32 },
}
38
/// Strategy used to pick each boosting stage's scaling factor.
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum LineSearchMethod {
    /// Doubling/halving search (the default).
    #[default]
    Binary,
    /// Golden-section search with a bounded iteration count.
    GoldenSection {
        max_iters: usize,
    },
}
54
/// The golden ratio φ ≈ 1.618, used by the golden-section line search.
const GOLDEN_RATIO: f64 = 1.618033988749895;
57
/// Natural Gradient Boosting model, generic over the predicted
/// distribution `D`, the scoring rule `S`, and the base learner `B`.
pub struct NGBoost<D, S, B>
where
    D: Distribution + Scorable<S> + Clone,
    S: Score,
    B: BaseLearner + Clone,
{
    /// Number of boosting stages to fit.
    pub n_estimators: u32,
    /// Base step size applied to each stage's update.
    pub learning_rate: f64,
    /// Use the natural (metric-preconditioned) gradient instead of the
    /// ordinary gradient.
    pub natural_gradient: bool,
    /// Fraction of rows sampled per stage, in (0, 1].
    pub minibatch_frac: f64,
    /// Fraction of columns sampled per stage, in (0, 1].
    pub col_sample: f64,
    /// Print progress during training.
    pub verbose: bool,
    /// Print every `verbose_eval` iterations when `verbose` is set.
    pub verbose_eval: u32,
    /// Tolerance used by the line searches.
    pub tol: f64,
    /// Stop after this many rounds without validation improvement.
    pub early_stopping_rounds: Option<u32>,
    /// Fraction of training data held out when early stopping needs an
    /// automatic validation split.
    pub validation_fraction: f64,
    /// Built-in decaying learning rate; when set it takes precedence over
    /// `lr_schedule` (see `compute_learning_rate`).
    pub adaptive_learning_rate: bool,
    /// Learning-rate schedule applied per iteration.
    pub lr_schedule: LearningRateSchedule,
    /// Tikhonov regularization added when computing the natural gradient.
    pub tikhonov_reg: f64,
    /// Strategy for choosing each stage's scaling factor.
    pub line_search_method: LineSearchMethod,

    // Prototype learner, cloned once per distribution parameter per stage.
    base_learner: B,

    /// Fitted learners: one inner Vec (one learner per distribution
    /// parameter) per boosting stage.
    pub base_models: Vec<Vec<Box<dyn TrainedBaseLearner>>>,
    /// Line-search scaling chosen at each stage.
    pub scalings: Vec<f64>,
    /// Initial (marginal) distribution parameters fitted to `y`.
    pub init_params: Option<Array1<f64>>,
    /// Column subset used at each stage.
    pub col_idxs: Vec<Vec<usize>>,
    // Optional callbacks for custom train/validation loss reporting.
    train_loss_monitor: Option<LossMonitor<D>>,
    val_loss_monitor: Option<LossMonitor<D>>,
    // Iteration index with the best validation loss (early stopping).
    best_val_loss_itr: Option<usize>,
    // Number of features seen during fit.
    n_features: Option<usize>,

    // Thread-local RNG used for row/column subsampling.
    rng: ThreadRng,

    // Marker types: D and S only appear in method signatures.
    _dist: PhantomData<D>,
    _score: PhantomData<S>,
}
105
106impl<D, S, B> NGBoost<D, S, B>
107where
108 D: Distribution + Scorable<S> + Clone,
109 S: Score,
110 B: BaseLearner + Clone,
111{
112 pub fn new(n_estimators: u32, learning_rate: f64, base_learner: B) -> Self {
113 NGBoost {
114 n_estimators,
115 learning_rate,
116 natural_gradient: true,
117 minibatch_frac: 1.0,
118 col_sample: 1.0,
119 verbose: false,
120 verbose_eval: 100,
121 tol: 1e-4,
122 early_stopping_rounds: None,
123 validation_fraction: 0.1,
124 adaptive_learning_rate: false,
125 lr_schedule: LearningRateSchedule::Constant,
126 tikhonov_reg: 0.0,
127 line_search_method: LineSearchMethod::Binary,
128 base_learner,
129 base_models: Vec::new(),
130 scalings: Vec::new(),
131 init_params: None,
132 col_idxs: Vec::new(),
133 train_loss_monitor: None,
134 val_loss_monitor: None,
135 best_val_loss_itr: None,
136 n_features: None,
137 rng: rng(),
138 _dist: PhantomData,
139 _score: PhantomData,
140 }
141 }
142
143 pub fn with_options(
144 n_estimators: u32,
145 learning_rate: f64,
146 base_learner: B,
147 natural_gradient: bool,
148 minibatch_frac: f64,
149 col_sample: f64,
150 verbose: bool,
151 verbose_eval: u32,
152 tol: f64,
153 early_stopping_rounds: Option<u32>,
154 validation_fraction: f64,
155 adaptive_learning_rate: bool,
156 ) -> Self {
157 NGBoost {
158 n_estimators,
159 learning_rate,
160 natural_gradient,
161 minibatch_frac,
162 col_sample,
163 verbose,
164 verbose_eval,
165 tol,
166 early_stopping_rounds,
167 validation_fraction,
168 adaptive_learning_rate,
169 lr_schedule: LearningRateSchedule::Constant,
170 tikhonov_reg: 0.0,
171 line_search_method: LineSearchMethod::Binary,
172 base_learner,
173 base_models: Vec::new(),
174 scalings: Vec::new(),
175 init_params: None,
176 col_idxs: Vec::new(),
177 train_loss_monitor: None,
178 val_loss_monitor: None,
179 best_val_loss_itr: None,
180 n_features: None,
181 rng: rng(),
182 _dist: PhantomData,
183 _score: PhantomData,
184 }
185 }
186
187 #[allow(clippy::too_many_arguments)]
189 pub fn with_advanced_options(
190 n_estimators: u32,
191 learning_rate: f64,
192 base_learner: B,
193 natural_gradient: bool,
194 minibatch_frac: f64,
195 col_sample: f64,
196 verbose: bool,
197 verbose_eval: u32,
198 tol: f64,
199 early_stopping_rounds: Option<u32>,
200 validation_fraction: f64,
201 lr_schedule: LearningRateSchedule,
202 tikhonov_reg: f64,
203 line_search_method: LineSearchMethod,
204 ) -> Self {
205 NGBoost {
206 n_estimators,
207 learning_rate,
208 natural_gradient,
209 minibatch_frac,
210 col_sample,
211 verbose,
212 verbose_eval,
213 tol,
214 early_stopping_rounds,
215 validation_fraction,
216 adaptive_learning_rate: false,
217 lr_schedule,
218 tikhonov_reg,
219 line_search_method,
220 base_learner,
221 base_models: Vec::new(),
222 scalings: Vec::new(),
223 init_params: None,
224 col_idxs: Vec::new(),
225 train_loss_monitor: None,
226 val_loss_monitor: None,
227 best_val_loss_itr: None,
228 n_features: None,
229 rng: rng(),
230 _dist: PhantomData,
231 _score: PhantomData,
232 }
233 }
234
    /// Install a custom callback for computing the reported training loss.
    pub fn set_train_loss_monitor(&mut self, monitor: LossMonitor<D>) {
        self.train_loss_monitor = Some(monitor);
    }

    /// Install a custom callback for computing the validation loss used by
    /// early stopping and progress reporting.
    pub fn set_val_loss_monitor(&mut self, monitor: LossMonitor<D>) {
        self.val_loss_monitor = Some(monitor);
    }
244
    /// Fit the model from scratch, discarding any previously fitted stages.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.fit_with_validation(x, y, None, None, None, None)
    }

    /// Continue boosting from the current state (warm start) on new data.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.partial_fit_with_validation(x, y, None, None, None, None)
    }

    /// Warm-start variant of [`Self::fit_with_validation`]: keeps already
    /// fitted stages and appends up to `n_estimators` new ones.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, false)
    }

    /// Fit from scratch with an optional explicit validation set and
    /// sample weights. `val_sample_weight` is accepted but currently
    /// unused by `fit_internal`.
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, true)
    }
284
285 fn validate_hyperparameters(&self) -> Result<(), &'static str> {
288 if self.n_estimators == 0 {
289 return Err("n_estimators must be greater than 0");
290 }
291 if self.learning_rate <= 0.0 {
292 return Err("learning_rate must be positive");
293 }
294 if self.learning_rate > 10.0 {
295 return Err("learning_rate > 10.0 is likely a mistake");
296 }
297 if self.minibatch_frac <= 0.0 || self.minibatch_frac > 1.0 {
298 return Err("minibatch_frac must be in (0, 1]");
299 }
300 if self.col_sample <= 0.0 || self.col_sample > 1.0 {
301 return Err("col_sample must be in (0, 1]");
302 }
303 if self.tol < 0.0 {
304 return Err("tol must be non-negative");
305 }
306 if self.validation_fraction < 0.0 || self.validation_fraction >= 1.0 {
307 return Err("validation_fraction must be in [0, 1)");
308 }
309 if self.tikhonov_reg < 0.0 {
310 return Err("tikhonov_reg must be non-negative");
311 }
312
313 match self.lr_schedule {
315 LearningRateSchedule::Linear {
316 decay_rate,
317 min_lr_fraction,
318 } => {
319 if decay_rate < 0.0 || decay_rate > 1.0 {
320 return Err("Linear schedule decay_rate must be in [0, 1]");
321 }
322 if min_lr_fraction < 0.0 || min_lr_fraction > 1.0 {
323 return Err("Linear schedule min_lr_fraction must be in [0, 1]");
324 }
325 }
326 LearningRateSchedule::Exponential { decay_rate } => {
327 if decay_rate < 0.0 {
328 return Err("Exponential schedule decay_rate must be non-negative");
329 }
330 }
331 LearningRateSchedule::CosineWarmRestarts { restart_period } => {
332 if restart_period == 0 {
333 return Err("CosineWarmRestarts restart_period must be > 0");
334 }
335 }
336 _ => {}
337 }
338
339 if let LineSearchMethod::GoldenSection { max_iters } = self.line_search_method {
341 if max_iters == 0 {
342 return Err("GoldenSection max_iters must be > 0");
343 }
344 }
345
346 Ok(())
347 }
348
    /// Core boosting loop shared by `fit` and `partial_fit`.
    ///
    /// Validates inputs, optionally carves out an automatic validation
    /// split for early stopping, fits the initial marginal distribution to
    /// `y`, then iteratively fits one base learner per distribution
    /// parameter on the (natural) gradient and applies the line-searched,
    /// scheduled update. `reset_state` distinguishes a fresh fit (clears
    /// fitted stages) from a warm start (keeps them).
    ///
    /// # Errors
    /// Returns a static message on invalid hyperparameters or data
    /// (shape mismatch, empty input, non-finite values).
    fn fit_internal(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        _val_sample_weight: Option<&Array1<f64>>, // accepted but currently unused
        reset_state: bool,
    ) -> Result<(), &'static str> {
        self.validate_hyperparameters()?;

        if x.nrows() != y.len() {
            return Err("Number of samples in X and y must match");
        }
        if x.nrows() == 0 {
            return Err("Cannot fit to empty dataset");
        }
        if x.ncols() == 0 {
            return Err("Cannot fit to dataset with no features");
        }

        // Reject NaN/inf up front so gradients stay finite.
        if x.iter().any(|&v| !v.is_finite()) {
            return Err("Input X contains NaN or infinite values");
        }
        if y.iter().any(|&v| !v.is_finite()) {
            return Err("Input y contains NaN or infinite values");
        }

        if reset_state {
            // Fresh fit: drop previously fitted stages.
            self.base_models.clear();
            self.scalings.clear();
            self.col_idxs.clear();
            self.best_val_loss_itr = None;
        }
        self.n_features = Some(x.ncols());

        // Carve out an automatic validation split only when early stopping
        // is requested and no explicit validation set was provided.
        let (x_train, y_train, x_val_auto, y_val_auto) = if self.early_stopping_rounds.is_some()
            && x_val.is_none()
            && y_val.is_none()
            && self.validation_fraction > 0.0
            && self.validation_fraction < 1.0
        {
            let n_samples = x.nrows();
            let n_val = ((n_samples as f64) * self.validation_fraction) as usize;
            let n_train = n_samples - n_val;

            // Fisher-Yates shuffle of the row indices.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            for i in (1..indices.len()).rev() {
                let j = self.rng.random_range(0..=i);
                indices.swap(i, j);
            }

            let train_indices: Vec<usize> = indices[0..n_train].to_vec();
            let val_indices: Vec<usize> = indices[n_train..].to_vec();

            // NOTE(review): `sample_weight` is NOT re-indexed to the train
            // subset here, so weights supplied together with this automatic
            // split are misaligned with the shuffled `x_train` rows —
            // verify intent.
            let x_train = x.select(ndarray::Axis(0), &train_indices);
            let y_train = y.select(ndarray::Axis(0), &train_indices);
            let x_val_auto = Some(x.select(ndarray::Axis(0), &val_indices));
            let y_val_auto = Some(y.select(ndarray::Axis(0), &val_indices));

            (x_train, y_train, x_val_auto, y_val_auto)
        } else {
            (x.to_owned(), y.to_owned(), x_val.cloned(), y_val.cloned())
        };

        let x_train = x_train;
        let y_train = y_train;
        // Prefer the automatic split when present, else the caller's set.
        let x_val = x_val_auto.as_ref().or(x_val);
        let y_val = y_val_auto.as_ref().or(y_val);

        if let (Some(xv), Some(yv)) = (x_val, y_val) {
            if xv.nrows() != yv.len() {
                return Err("Number of samples in validation X and y must match");
            }
            if xv.ncols() != x_train.ncols() {
                return Err("Number of features in training and validation data must match");
            }
        }

        // Initialize every row's parameters to the marginal fit of y.
        self.init_params = Some(D::fit(&y_train));
        let n_params = self.init_params.as_ref().unwrap().len();
        let mut params = Array2::from_elem((x_train.nrows(), n_params), 0.0);

        let init_params = self.init_params.as_ref().unwrap();
        params
            .outer_iter_mut()
            .for_each(|mut row| row.assign(init_params));

        // Validation rows start from the same initial parameters.
        let mut val_params = if let (Some(xv), Some(_yv)) = (x_val, y_val) {
            let mut v_params = Array2::from_elem((xv.nrows(), n_params), 0.0);
            v_params
                .outer_iter_mut()
                .for_each(|mut row| row.assign(init_params));
            Some(v_params)
        } else {
            None
        };

        let mut best_val_loss = f64::INFINITY;
        let mut best_iter = 0;
        let mut no_improvement_count = 0;

        for itr in 0..self.n_estimators {
            let dist = D::from_params(&params);

            // Gradient of the score w.r.t. the parameters, optionally
            // natural (metric-preconditioned) and Tikhonov-regularized.
            let grads = if self.natural_gradient && self.tikhonov_reg > 0.0 {
                let standard_grad = Scorable::d_score(&dist, &y_train);
                let metric = Scorable::metric(&dist);
                crate::scores::natural_gradient_regularized(
                    &standard_grad,
                    &metric,
                    self.tikhonov_reg,
                )
            } else {
                Scorable::grad(&dist, &y_train, self.natural_gradient)
            };

            // Row/column subsample for this stage.
            let (row_idxs, col_idxs, x_sampled, y_sampled, params_sampled, weight_sampled) =
                self.sample(&x_train, &y_train, &params, sample_weight);
            self.col_idxs.push(col_idxs.clone());

            let grads_sampled = grads.select(ndarray::Axis(0), &row_idxs);

            // Fit one base learner per distribution parameter, in parallel
            // when the `parallel` feature is enabled.
            #[cfg(feature = "parallel")]
            let fit_results: Vec<
                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
            > = {
                let learners: Vec<B> = (0..n_params).map(|_| self.base_learner.clone()).collect();
                learners
                    .into_par_iter()
                    .enumerate()
                    .map(|(j, learner)| {
                        let grad_j = grads_sampled.column(j).to_owned();
                        let fitted = learner.fit_with_weights(
                            &x_sampled,
                            &grad_j,
                            weight_sampled.as_ref(),
                        )?;
                        let preds = fitted.predict(&x_sampled);
                        Ok((fitted, preds))
                    })
                    .collect()
            };

            #[cfg(not(feature = "parallel"))]
            let fit_results: Vec<
                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
            > = (0..n_params)
                .map(|j| {
                    let grad_j = grads_sampled.column(j).to_owned();
                    let learner = self.base_learner.clone();
                    let fitted =
                        learner.fit_with_weights(&x_sampled, &grad_j, weight_sampled.as_ref())?;
                    let preds = fitted.predict(&x_sampled);
                    Ok((fitted, preds))
                })
                .collect();

            // Propagate the first learner-fitting error, if any.
            let mut fitted_learners: Vec<Box<dyn TrainedBaseLearner>> =
                Vec::with_capacity(n_params);
            let mut predictions_cols: Vec<Array1<f64>> = Vec::with_capacity(n_params);
            for result in fit_results {
                let (fitted, preds) = result?;
                fitted_learners.push(fitted);
                predictions_cols.push(preds);
            }

            let predictions = to_2d_array(predictions_cols);

            // Scale factor chosen by line search on the sampled batch.
            let scale = self.line_search(
                &predictions,
                &params_sampled,
                &y_sampled,
                weight_sampled.as_ref(),
            );
            self.scalings.push(scale);
            self.base_models.push(fitted_learners);

            let progress = itr as f64 / self.n_estimators as f64;
            let effective_learning_rate = self.compute_learning_rate(itr, progress);

            // Apply the stage update to ALL training rows (not just the
            // subsample), respecting any column subsampling.
            let fitted_learners = self.base_models.last().unwrap();
            let full_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == x_train.ncols() {
                fitted_learners
                    .iter()
                    .map(|learner| learner.predict(&x_train))
                    .collect()
            } else {
                let x_subset = x_train.select(ndarray::Axis(1), &col_idxs);
                fitted_learners
                    .iter()
                    .map(|learner| learner.predict(&x_subset))
                    .collect()
            };
            let full_predictions = to_2d_array(full_predictions_cols);

            params -= &(effective_learning_rate * scale * &full_predictions);

            if let (Some(xv), Some(yv), Some(vp)) = (x_val, y_val, val_params.as_mut()) {
                // Mirror the update on the validation rows, then track the
                // best iteration for early stopping.
                let fitted_learners = self.base_models.last().unwrap();
                let val_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == xv.ncols() {
                    fitted_learners
                        .iter()
                        .map(|learner| learner.predict(xv))
                        .collect()
                } else {
                    let xv_subset = xv.select(ndarray::Axis(1), &col_idxs);
                    fitted_learners
                        .iter()
                        .map(|learner| learner.predict(&xv_subset))
                        .collect()
                };
                let val_predictions = to_2d_array(val_predictions_cols);
                *vp -= &(effective_learning_rate * scale * &val_predictions);

                let val_dist = D::from_params(vp);
                let val_loss = if let Some(monitor) = &self.val_loss_monitor {
                    monitor(&val_dist, yv, None)
                } else {
                    Scorable::total_score(&val_dist, yv, None)
                };

                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    best_iter = itr;
                    no_improvement_count = 0;
                    self.best_val_loss_itr = Some(itr as usize);
                } else {
                    no_improvement_count += 1;
                }

                if let Some(rounds) = self.early_stopping_rounds {
                    if no_improvement_count >= rounds {
                        if self.verbose {
                            println!("== Early stopping achieved.");
                            println!(
                                "== Best iteration / VAL{} (val_loss={:.4})",
                                best_iter, best_val_loss
                            );
                        }
                        break;
                    }
                }

                if self.verbose && itr % self.verbose_eval == 0 {
                    let dist = D::from_params(&params);
                    let train_loss = if let Some(monitor) = &self.train_loss_monitor {
                        monitor(&dist, &y_train, None)
                    } else {
                        Scorable::total_score(&dist, &y_train, None)
                    };
                    println!(
                        "[iter {}] train_loss={:.4} val_loss={:.4}",
                        itr, train_loss, val_loss
                    );
                }
            } else {
                if self.verbose && itr % self.verbose_eval == 0 {
                    let dist = D::from_params(&params);
                    let loss = if let Some(monitor) = &self.train_loss_monitor {
                        monitor(&dist, &y_train, None)
                    } else {
                        Scorable::total_score(&dist, &y_train, None)
                    };

                    // Gradient magnitude for progress reporting.
                    let grad_norm: f64 =
                        grads.iter().map(|x| x * x).sum::<f64>().sqrt() / grads.len() as f64;

                    println!(
                        "[iter {}] loss={:.4} grad_norm={:.4} scale={:.4}",
                        itr, loss, grad_norm, scale
                    );
                }
            }
        }

        Ok(())
    }
661
662 fn sample(
663 &mut self,
664 x: &Array2<f64>,
665 y: &Array1<f64>,
666 params: &Array2<f64>,
667 sample_weight: Option<&Array1<f64>>,
668 ) -> (
669 Vec<usize>,
670 Vec<usize>,
671 Array2<f64>,
672 Array1<f64>,
673 Array2<f64>,
674 Option<Array1<f64>>,
675 ) {
676 let n_samples = x.nrows();
677 let n_features = x.ncols();
678
679 let sample_size = if self.minibatch_frac >= 1.0 {
681 n_samples
682 } else {
683 ((n_samples as f64) * self.minibatch_frac) as usize
684 };
685
686 let row_idxs: Vec<usize> = if sample_size == n_samples {
690 (0..n_samples).collect()
691 } else {
692 let mut indices: Vec<usize> = (0..n_samples).collect();
693 for i in (1..indices.len()).rev() {
695 let j = self.rng.random_range(0..=i);
696 indices.swap(i, j);
697 }
698 indices.into_iter().take(sample_size).collect()
699 };
700
701 let col_size = if self.col_sample >= 1.0 {
703 n_features
704 } else if self.col_sample > 0.0 {
705 ((n_features as f64) * self.col_sample) as usize
706 } else {
707 0
708 };
709
710 let col_idxs: Vec<usize> = if col_size == n_features || col_size == 0 {
711 (0..n_features).collect()
712 } else {
713 let mut indices: Vec<usize> = (0..n_features).collect();
714 indices.shuffle(&mut self.rng);
715 indices.into_iter().take(col_size).collect()
716 };
717
718 let x_sampled = if col_size == n_features {
722 x.select(ndarray::Axis(0), &row_idxs)
724 } else {
725 let mut result = Array2::zeros((row_idxs.len(), col_idxs.len()));
727 for (new_row, &old_row) in row_idxs.iter().enumerate() {
728 for (new_col, &old_col) in col_idxs.iter().enumerate() {
729 result[[new_row, new_col]] = x[[old_row, old_col]];
730 }
731 }
732 result
733 };
734 let y_sampled = y.select(ndarray::Axis(0), &row_idxs);
735 let params_sampled = params.select(ndarray::Axis(0), &row_idxs);
736
737 let sample_weights_sampled =
739 sample_weight.map(|weights| weights.select(ndarray::Axis(0), &row_idxs));
740
741 (
742 row_idxs,
743 col_idxs,
744 x_sampled,
745 y_sampled,
746 params_sampled,
747 sample_weights_sampled,
748 )
749 }
750
    /// Predicted distribution parameters for `x` using every fitted stage.
    fn get_params(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params_at(x, None)
    }
754
755 fn get_params_at(&self, x: &Array2<f64>, max_iter: Option<usize>) -> Array2<f64> {
756 if x.nrows() == 0 {
757 return Array2::zeros((0, 0));
758 }
759
760 let init_params = self
761 .init_params
762 .as_ref()
763 .expect("Model has not been fitted. Call fit() before predict().");
764 let n_params = init_params.len();
765 let mut params = Array2::from_elem((x.nrows(), n_params), 0.0);
766 params
767 .outer_iter_mut()
768 .for_each(|mut row| row.assign(init_params));
769
770 let n_iters = max_iter
771 .unwrap_or(self.base_models.len())
772 .min(self.base_models.len());
773
774 for (i, (learners, col_idx)) in self
775 .base_models
776 .iter()
777 .zip(self.col_idxs.iter())
778 .enumerate()
779 .take(n_iters)
780 {
781 let scale = self.scalings[i];
782
783 let predictions_cols: Vec<Array1<f64>> = if col_idx.len() == x.ncols() {
786 learners.iter().map(|learner| learner.predict(x)).collect()
787 } else {
788 let x_subset = x.select(ndarray::Axis(1), col_idx);
789 learners
790 .iter()
791 .map(|learner| learner.predict(&x_subset))
792 .collect()
793 };
794
795 let predictions = to_2d_array(predictions_cols);
796
797 params -= &(self.learning_rate * scale * &predictions);
798 }
799 params
800 }
801
    /// Predicted distribution parameters for `x` using all fitted stages.
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params(x)
    }

    /// Predicted distribution parameters using only the first `max_iter`
    /// stages.
    pub fn pred_param_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        self.get_params_at(x, Some(max_iter))
    }

    /// Predicted distribution for the rows of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> D {
        let params = self.get_params(x);
        D::from_params(&params)
    }

    /// Predicted distribution using only the first `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> D {
        let params = self.get_params_at(x, Some(max_iter));
        D::from_params(&params)
    }

    /// Point predictions (the distribution's `predict`).
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.pred_dist(x).predict()
    }

    /// Point predictions using only the first `max_iter` stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.pred_dist_at(x, max_iter).predict()
    }

    /// Iterator of point predictions after 1, 2, ..., n fitted stages.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        (1..=self.base_models.len()).map(move |i| self.predict_at(x, i))
    }

    /// Iterator of predicted distributions after 1, 2, ..., n stages.
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = D> + 'a {
        (1..=self.base_models.len()).map(move |i| self.pred_dist_at(x, i))
    }

    /// Total score of the fitted model on `(x, y)` under `S`.
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        let dist = self.pred_dist(x);
        Scorable::total_score(&dist, y, None)
    }

    /// Number of features seen during fitting, if fitted.
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }
855
856 fn compute_learning_rate(&self, iteration: u32, progress: f64) -> f64 {
858 if self.adaptive_learning_rate {
860 return self.learning_rate * (1.0 - 0.7 * progress).max(0.1);
861 }
862
863 match self.lr_schedule {
864 LearningRateSchedule::Constant => self.learning_rate,
865 LearningRateSchedule::Linear {
866 decay_rate,
867 min_lr_fraction,
868 } => self.learning_rate * (1.0 - decay_rate * progress).max(min_lr_fraction),
869 LearningRateSchedule::Exponential { decay_rate } => {
870 self.learning_rate * (-decay_rate * progress).exp()
871 }
872 LearningRateSchedule::Cosine => {
873 self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * progress).cos())
874 }
875 LearningRateSchedule::CosineWarmRestarts { restart_period } => {
876 let period_progress = (iteration % restart_period) as f64 / restart_period as f64;
877 self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * period_progress).cos())
878 }
879 }
880 }
881
882 pub fn feature_importances(&self) -> Option<Array2<f64>> {
887 let n_features = self.n_features?;
888 if self.base_models.is_empty() || n_features == 0 {
889 return None;
890 }
891
892 let n_params = self.init_params.as_ref()?.len();
893 let mut importances = Array2::zeros((n_params, n_features));
894
895 for (iter_idx, learners) in self.base_models.iter().enumerate() {
897 let scale = self.scalings[iter_idx].abs();
898
899 for (param_idx, learner) in learners.iter().enumerate() {
900 if let Some(feature_idx) = learner.split_feature() {
901 if feature_idx < n_features {
902 importances[[param_idx, feature_idx]] += scale;
903 }
904 }
905 }
906 }
907
908 for mut row in importances.rows_mut() {
910 let sum: f64 = row.sum();
911 if sum > 0.0 {
912 row.mapv_inplace(|v| v / sum);
913 }
914 }
915
916 Some(importances)
917 }
918
    /// Rescale the model's baseline uncertainty so that it matches the
    /// empirical squared error observed on a validation set.
    ///
    /// Only adjusts the second entry of `init_params`, which this routine
    /// assumes is a log-scale parameter — TODO confirm against the
    /// distribution's parameterization.
    ///
    /// NOTE(review): `(-init_params[1]).exp()` uses a negated exponent to
    /// recover the current variance; verify the sign and the σ-vs-σ²
    /// convention against `Normal`. Also note only the shared baseline is
    /// shifted, not the per-sample parameters produced by the stages.
    ///
    /// # Errors
    /// Fails if the model has not been trained yet.
    pub fn calibrate_uncertainty(
        &mut self,
        x_val: &Array2<f64>,
        y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        if self.base_models.is_empty() {
            return Err("Model must be trained before calibration");
        }

        let params = self.pred_param(x_val);
        let dist = D::from_params(&params);

        let predictions = dist.predict();
        let errors = y_val - &predictions;

        // Mean squared error as the empirical variance target
        // (defaults to 1.0 on an empty validation set).
        let empirical_var = errors.mapv(|e| e * e).mean().unwrap_or(1.0);

        if let Some(init_params) = self.init_params.as_mut() {
            if init_params.len() >= 2 {
                let current_var = (-init_params[1]).exp();
                let target_var = empirical_var;
                // Shift the log-scale parameter by log of the std ratio.
                let calibration_factor = (target_var / current_var).sqrt();
                init_params[1] += calibration_factor.ln();
            }
        }

        Ok(())
    }
954
955 pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
958 let importances = self.feature_importances()?;
959 let mut aggregated = importances.sum_axis(ndarray::Axis(0));
960
961 let sum: f64 = aggregated.sum();
962 if sum > 0.0 {
963 aggregated.mapv_inplace(|v| v / sum);
964 }
965
966 Some(aggregated)
967 }
968
    /// Choose the scaling factor for the current stage's update using the
    /// configured line-search strategy.
    fn line_search(
        &self,
        resids: &Array2<f64>,
        start: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> f64 {
        match self.line_search_method {
            LineSearchMethod::Binary => self.line_search_binary(resids, start, y, sample_weight),
            LineSearchMethod::GoldenSection { max_iters } => {
                self.line_search_golden_section(resids, start, y, sample_weight, max_iters)
            }
        }
    }
983
984 fn line_search_binary(
986 &self,
987 resids: &Array2<f64>,
988 start: &Array2<f64>,
989 y: &Array1<f64>,
990 sample_weight: Option<&Array1<f64>>,
991 ) -> f64 {
992 let mut scale = 1.0;
993 let initial_score = Scorable::total_score(&D::from_params(start), y, sample_weight);
994
995 loop {
997 if scale > 256.0 {
998 break;
999 }
1000 let scaled_resids = resids * (scale * 2.0);
1001 let next_params = start - &scaled_resids;
1002 let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
1003 if score >= initial_score || !score.is_finite() {
1004 break;
1005 }
1006 scale *= 2.0;
1007 }
1008
1009 loop {
1011 let scaled_resids = resids * scale;
1012 let norm: f64 = scaled_resids
1013 .rows()
1014 .into_iter()
1015 .map(|row| row.iter().map(|x| x * x).sum::<f64>().sqrt())
1016 .sum::<f64>()
1017 / scaled_resids.nrows() as f64;
1018 if norm < self.tol {
1019 break;
1020 }
1021
1022 let next_params = start - &scaled_resids;
1023 let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
1024 if score < initial_score && score.is_finite() {
1025 break;
1026 }
1027 scale *= 0.5;
1028
1029 if scale < 1e-10 {
1030 break;
1031 }
1032 }
1033
1034 scale
1035 }
1036
1037 fn line_search_golden_section(
1040 &self,
1041 resids: &Array2<f64>,
1042 start: &Array2<f64>,
1043 y: &Array1<f64>,
1044 sample_weight: Option<&Array1<f64>>,
1045 max_iters: usize,
1046 ) -> f64 {
1047 let compute_score = |scale: f64| -> f64 {
1049 let scaled_resids = resids * scale;
1050 let next_params = start - &scaled_resids;
1051 Scorable::total_score(&D::from_params(&next_params), y, sample_weight)
1052 };
1053
1054 let initial_score = compute_score(0.0);
1055
1056 let mut upper = 1.0;
1058 while upper < 256.0 {
1059 let score = compute_score(upper * 2.0);
1060 if score >= initial_score || !score.is_finite() {
1061 break;
1062 }
1063 upper *= 2.0;
1064 }
1065
1066 let mut a = 0.0;
1068 let mut b = upper;
1069 let inv_phi = 1.0 / GOLDEN_RATIO;
1070 let _inv_phi2 = 1.0 / (GOLDEN_RATIO * GOLDEN_RATIO); let mut c = b - (b - a) * inv_phi;
1074 let mut d = a + (b - a) * inv_phi;
1075 let mut fc = compute_score(c);
1076 let mut fd = compute_score(d);
1077
1078 for _ in 0..max_iters {
1079 if (b - a).abs() < self.tol {
1080 break;
1081 }
1082
1083 if fc < fd {
1084 b = d;
1086 d = c;
1087 fd = fc;
1088 c = b - (b - a) * inv_phi;
1089 fc = compute_score(c);
1090 } else {
1091 a = c;
1093 c = d;
1094 fc = fd;
1095 d = a + (b - a) * inv_phi;
1096 fd = compute_score(d);
1097 }
1098 }
1099
1100 let scale = (a + b) / 2.0;
1102
1103 let final_score = compute_score(scale);
1105 if final_score < initial_score && final_score.is_finite() {
1106 scale
1107 } else {
1108 1.0
1110 }
1111 }
1112
    /// Convert the fitted model into a serde-friendly representation.
    ///
    /// NOTE(review): learners whose `to_serializable()` returns `None` are
    /// silently dropped by `filter_map`, which can leave `base_models` out
    /// of step with `scalings`/`col_idxs` after a round trip — verify that
    /// all learner types used here are serializable. Also note that the
    /// `adaptive_learning_rate` flag is not persisted.
    pub fn serialize(&self) -> Result<SerializedNGBoost, Box<dyn std::error::Error>> {
        let serialized_base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>> = self
            .base_models
            .iter()
            .map(|learners| {
                learners
                    .iter()
                    .filter_map(|learner| learner.to_serializable())
                    .collect()
            })
            .collect();

        Ok(SerializedNGBoost {
            n_estimators: self.n_estimators,
            learning_rate: self.learning_rate,
            natural_gradient: self.natural_gradient,
            minibatch_frac: self.minibatch_frac,
            col_sample: self.col_sample,
            verbose: self.verbose,
            verbose_eval: self.verbose_eval,
            tol: self.tol,
            early_stopping_rounds: self.early_stopping_rounds,
            validation_fraction: self.validation_fraction,
            init_params: self.init_params.as_ref().map(|p| p.to_vec()),
            scalings: self.scalings.clone(),
            col_idxs: self.col_idxs.clone(),
            best_val_loss_itr: self.best_val_loss_itr,
            base_models: serialized_base_models,
            lr_schedule: self.lr_schedule,
            tikhonov_reg: self.tikhonov_reg,
            line_search_method: self.line_search_method,
            n_features: self.n_features,
        })
    }
1149
1150 pub fn deserialize(
1152 serialized: SerializedNGBoost,
1153 base_learner: B,
1154 ) -> Result<Self, Box<dyn std::error::Error>>
1155 where
1156 D: Distribution + Scorable<S> + Clone,
1157 S: Score,
1158 B: BaseLearner + Clone,
1159 {
1160 let mut model = Self::with_options(
1161 serialized.n_estimators,
1162 serialized.learning_rate,
1163 base_learner,
1164 serialized.natural_gradient,
1165 serialized.minibatch_frac,
1166 serialized.col_sample,
1167 serialized.verbose,
1168 serialized.verbose_eval,
1169 serialized.tol,
1170 serialized.early_stopping_rounds,
1171 serialized.validation_fraction,
1172 false, );
1174
1175 if let Some(init_params) = serialized.init_params {
1177 model.init_params = Some(Array1::from(init_params));
1178 }
1179 model.scalings = serialized.scalings;
1180 model.col_idxs = serialized.col_idxs;
1181 model.best_val_loss_itr = serialized.best_val_loss_itr;
1182
1183 model.lr_schedule = serialized.lr_schedule;
1185 model.tikhonov_reg = serialized.tikhonov_reg;
1186 model.line_search_method = serialized.line_search_method;
1187 model.n_features = serialized.n_features;
1188
1189 model.base_models = serialized
1191 .base_models
1192 .into_iter()
1193 .map(|learners| learners.into_iter().map(|l| l.to_trait_object()).collect())
1194 .collect();
1195
1196 Ok(model)
1197 }
1198}
1199
/// Serde-serializable snapshot of a fitted `NGBoost` model.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` values
/// when absent from the payload.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct SerializedNGBoost {
    pub n_estimators: u32,
    pub learning_rate: f64,
    pub natural_gradient: bool,
    pub minibatch_frac: f64,
    pub col_sample: f64,
    pub verbose: bool,
    pub verbose_eval: u32,
    pub tol: f64,
    pub early_stopping_rounds: Option<u32>,
    pub validation_fraction: f64,
    /// Initial marginal distribution parameters, if fitted.
    pub init_params: Option<Vec<f64>>,
    /// Per-stage line-search scalings.
    pub scalings: Vec<f64>,
    /// Per-stage column subsets.
    pub col_idxs: Vec<Vec<usize>>,
    pub best_val_loss_itr: Option<usize>,
    /// Per-stage serializable learners (one per distribution parameter).
    pub base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>>,
    #[serde(default)]
    pub lr_schedule: LearningRateSchedule,
    #[serde(default)]
    pub tikhonov_reg: f64,
    #[serde(default)]
    pub line_search_method: LineSearchMethod,
    #[serde(default)]
    pub n_features: Option<usize>,
}
1232
1233fn to_2d_array(cols: Vec<Array1<f64>>) -> Array2<f64> {
1234 if cols.is_empty() {
1235 return Array2::zeros((0, 0));
1236 }
1237 let nrows = cols[0].len();
1238 let ncols = cols.len();
1239 let mut arr = Array2::zeros((nrows, ncols));
1240 for (j, col) in cols.iter().enumerate() {
1241 arr.column_mut(j).assign(col);
1242 }
1243 arr
1244}
1245
/// Convenience wrapper: NGBoost regressor with a `Normal` output
/// distribution, `LogScore`, and decision-tree base learners.
pub struct NGBRegressor {
    model: NGBoost<Normal, LogScore, DecisionTreeLearner>,
}
1250
/// Convenience wrapper: NGBoost classifier with a `Bernoulli` output
/// distribution, `LogScore`, and decision-tree base learners.
pub struct NGBClassifier {
    model: NGBoost<Bernoulli, LogScore, DecisionTreeLearner>,
}
1254
1255impl NGBRegressor {
    /// Build a regressor with a `Normal` distribution, `LogScore`, and the
    /// default decision-tree base learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Fit from scratch on `(x, y)`.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with an optional explicit validation set (no sample weights).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Warm start: continue boosting from the current state.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Warm-start fit with an optional explicit validation set.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Point predictions for `x`.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Point predictions using only the first `max_iter` stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Iterator of point predictions after each successive stage.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Predicted `Normal` distribution for the rows of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Normal {
        self.model.pred_dist(x)
    }

    /// Predicted distribution using only the first `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Normal {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Iterator of predicted distributions after each successive stage.
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = Normal> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Raw predicted distribution parameters for `x`.
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Total score of the fitted model on `(x, y)` (lower is better).
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }

    /// Install a custom training-loss callback.
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Install a custom validation-loss callback.
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }
1355
1356 pub fn with_options(
1358 n_estimators: u32,
1359 learning_rate: f64,
1360 natural_gradient: bool,
1361 minibatch_frac: f64,
1362 col_sample: f64,
1363 verbose: bool,
1364 verbose_eval: u32,
1365 tol: f64,
1366 early_stopping_rounds: Option<u32>,
1367 validation_fraction: f64,
1368 adaptive_learning_rate: bool,
1369 ) -> Self {
1370 Self {
1371 model: NGBoost::with_options(
1372 n_estimators,
1373 learning_rate,
1374 default_tree_learner(),
1375 natural_gradient,
1376 minibatch_frac,
1377 col_sample,
1378 verbose,
1379 verbose_eval,
1380 tol,
1381 early_stopping_rounds,
1382 validation_fraction,
1383 adaptive_learning_rate,
1384 ),
1385 }
1386 }
1387
1388 pub fn with_options_compat(
1390 n_estimators: u32,
1391 learning_rate: f64,
1392 natural_gradient: bool,
1393 minibatch_frac: f64,
1394 col_sample: f64,
1395 verbose: bool,
1396 verbose_eval: u32,
1397 tol: f64,
1398 early_stopping_rounds: Option<u32>,
1399 validation_fraction: f64,
1400 ) -> Self {
1401 Self::with_options(
1402 n_estimators,
1403 learning_rate,
1404 natural_gradient,
1405 minibatch_frac,
1406 col_sample,
1407 verbose,
1408 verbose_eval,
1409 tol,
1410 early_stopping_rounds,
1411 validation_fraction,
1412 false, )
1414 }
1415
1416 pub fn set_adaptive_learning_rate(&mut self, enabled: bool) {
1418 self.model.adaptive_learning_rate = enabled;
1419 }
1420
1421 pub fn calibrate_uncertainty(
1424 &mut self,
1425 x_val: &Array2<f64>,
1426 y_val: &Array1<f64>,
1427 ) -> Result<(), &'static str> {
1428 self.model.calibrate_uncertainty(x_val, y_val)
1429 }
1430
1431 pub fn n_estimators(&self) -> u32 {
1433 self.model.n_estimators
1434 }
1435
1436 pub fn learning_rate(&self) -> f64 {
1438 self.model.learning_rate
1439 }
1440
1441 pub fn natural_gradient(&self) -> bool {
1443 self.model.natural_gradient
1444 }
1445
1446 pub fn minibatch_frac(&self) -> f64 {
1448 self.model.minibatch_frac
1449 }
1450
1451 pub fn col_sample(&self) -> f64 {
1453 self.model.col_sample
1454 }
1455
1456 pub fn best_val_loss_itr(&self) -> Option<usize> {
1458 self.model.best_val_loss_itr
1459 }
1460
1461 pub fn early_stopping_rounds(&self) -> Option<u32> {
1463 self.model.early_stopping_rounds
1464 }
1465
1466 pub fn validation_fraction(&self) -> f64 {
1468 self.model.validation_fraction
1469 }
1470
1471 pub fn n_features(&self) -> Option<usize> {
1473 self.model.n_features()
1474 }
1475
1476 pub fn feature_importances(&self) -> Option<Array2<f64>> {
1479 self.model.feature_importances()
1480 }
1481
1482 pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
1485 self.model.feature_importances_aggregated()
1486 }
1487
1488 pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
1490 let serialized = self.model.serialize()?;
1491 let encoded = bincode::serialize(&serialized)?;
1492 std::fs::write(path, encoded)?;
1493 Ok(())
1494 }
1495
1496 pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
1498 let encoded = std::fs::read(path)?;
1499 let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
1500 let model = NGBoost::<Normal, LogScore, DecisionTreeLearner>::deserialize(
1501 serialized,
1502 default_tree_learner(),
1503 )?;
1504 Ok(Self { model })
1505 }
1506}
1507
1508impl NGBClassifier {
1509 pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
1510 Self {
1511 model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
1512 }
1513 }
1514
1515 pub fn with_options(
1516 n_estimators: u32,
1517 learning_rate: f64,
1518 natural_gradient: bool,
1519 minibatch_frac: f64,
1520 col_sample: f64,
1521 verbose: bool,
1522 verbose_eval: u32,
1523 tol: f64,
1524 early_stopping_rounds: Option<u32>,
1525 validation_fraction: f64,
1526 adaptive_learning_rate: bool,
1527 ) -> Self {
1528 Self {
1529 model: NGBoost::with_options(
1530 n_estimators,
1531 learning_rate,
1532 default_tree_learner(),
1533 natural_gradient,
1534 minibatch_frac,
1535 col_sample,
1536 verbose,
1537 verbose_eval,
1538 tol,
1539 early_stopping_rounds,
1540 validation_fraction,
1541 adaptive_learning_rate,
1542 ),
1543 }
1544 }
1545
1546 pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
1548 where
1549 F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
1550 {
1551 self.model.set_train_loss_monitor(Box::new(monitor));
1552 }
1553
1554 pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
1556 where
1557 F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
1558 {
1559 self.model.set_val_loss_monitor(Box::new(monitor));
1560 }
1561
1562 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
1563 self.model.fit(x, y)
1564 }
1565
1566 pub fn fit_with_validation(
1567 &mut self,
1568 x: &Array2<f64>,
1569 y: &Array1<f64>,
1570 x_val: Option<&Array2<f64>>,
1571 y_val: Option<&Array1<f64>>,
1572 ) -> Result<(), &'static str> {
1573 self.model
1574 .fit_with_validation(x, y, x_val, y_val, None, None)
1575 }
1576
1577 pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
1584 self.model.partial_fit(x, y)
1585 }
1586
1587 pub fn partial_fit_with_validation(
1589 &mut self,
1590 x: &Array2<f64>,
1591 y: &Array1<f64>,
1592 x_val: Option<&Array2<f64>>,
1593 y_val: Option<&Array1<f64>>,
1594 ) -> Result<(), &'static str> {
1595 self.model
1596 .partial_fit_with_validation(x, y, x_val, y_val, None, None)
1597 }
1598
1599 pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
1600 self.model.predict(x)
1601 }
1602
1603 pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
1605 self.model.predict_at(x, max_iter)
1606 }
1607
1608 pub fn staged_predict<'a>(
1610 &'a self,
1611 x: &'a Array2<f64>,
1612 ) -> impl Iterator<Item = Array1<f64>> + 'a {
1613 self.model.staged_predict(x)
1614 }
1615
1616 pub fn predict_proba(&self, x: &Array2<f64>) -> Array2<f64> {
1617 let dist = self.model.pred_dist(x);
1618 dist.class_probs()
1619 }
1620
1621 pub fn predict_proba_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
1623 let dist = self.model.pred_dist_at(x, max_iter);
1624 dist.class_probs()
1625 }
1626
1627 pub fn staged_predict_proba<'a>(
1629 &'a self,
1630 x: &'a Array2<f64>,
1631 ) -> impl Iterator<Item = Array2<f64>> + 'a {
1632 (1..=self.model.base_models.len()).map(move |i| self.predict_proba_at(x, i))
1633 }
1634
1635 pub fn pred_dist(&self, x: &Array2<f64>) -> Bernoulli {
1636 self.model.pred_dist(x)
1637 }
1638
1639 pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Bernoulli {
1641 self.model.pred_dist_at(x, max_iter)
1642 }
1643
1644 pub fn staged_pred_dist<'a>(
1646 &'a self,
1647 x: &'a Array2<f64>,
1648 ) -> impl Iterator<Item = Bernoulli> + 'a {
1649 self.model.staged_pred_dist(x)
1650 }
1651
1652 pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
1654 self.model.pred_param(x)
1655 }
1656
1657 pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
1659 self.model.score(x, y)
1660 }
1661
1662 pub fn n_estimators(&self) -> u32 {
1664 self.model.n_estimators
1665 }
1666
1667 pub fn learning_rate(&self) -> f64 {
1669 self.model.learning_rate
1670 }
1671
1672 pub fn natural_gradient(&self) -> bool {
1674 self.model.natural_gradient
1675 }
1676
1677 pub fn minibatch_frac(&self) -> f64 {
1679 self.model.minibatch_frac
1680 }
1681
1682 pub fn col_sample(&self) -> f64 {
1684 self.model.col_sample
1685 }
1686
1687 pub fn best_val_loss_itr(&self) -> Option<usize> {
1689 self.model.best_val_loss_itr
1690 }
1691
1692 pub fn early_stopping_rounds(&self) -> Option<u32> {
1694 self.model.early_stopping_rounds
1695 }
1696
1697 pub fn validation_fraction(&self) -> f64 {
1699 self.model.validation_fraction
1700 }
1701
1702 pub fn n_features(&self) -> Option<usize> {
1704 self.model.n_features()
1705 }
1706
1707 pub fn feature_importances(&self) -> Option<Array2<f64>> {
1710 self.model.feature_importances()
1711 }
1712
1713 pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
1716 self.model.feature_importances_aggregated()
1717 }
1718
1719 pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
1721 let serialized = self.model.serialize()?;
1722 let encoded = bincode::serialize(&serialized)?;
1723 std::fs::write(path, encoded)?;
1724 Ok(())
1725 }
1726
1727 pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
1729 let encoded = std::fs::read(path)?;
1730 let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
1731 let model = NGBoost::<Bernoulli, LogScore, DecisionTreeLearner>::deserialize(
1732 serialized,
1733 default_tree_learner(),
1734 )?;
1735 Ok(Self { model })
1736 }
1737}