// ngboost_rs/ngboost.rs — core NGBoost (Natural Gradient Boosting) implementation.

/// Type alias for loss monitor functions.
///
/// A monitor receives the predicted distribution, the target values, and an
/// optional per-sample weight vector (the training loop currently passes
/// `None` for the weights), and returns a scalar loss. `Send + Sync` is
/// required so monitors remain usable when the `parallel` feature is enabled.
pub type LossMonitor<D> = Box<dyn Fn(&D, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync>;
3
4use crate::dist::categorical::Bernoulli;
5use crate::dist::normal::Normal;
6use crate::dist::{ClassificationDistn, Distribution};
7use crate::learners::{default_tree_learner, BaseLearner, DecisionTreeLearner, TrainedBaseLearner};
8use crate::scores::{LogScore, Scorable, Score};
9use ndarray::{Array1, Array2};
10use rand::prelude::*;
11use rand::rng;
12use std::marker::PhantomData;
13
14#[cfg(feature = "parallel")]
15use rayon::prelude::*;
16
/// Learning rate schedule for controlling step size during training.
///
/// A schedule maps training progress (fraction of iterations completed) to a
/// multiplier on the base `learning_rate`; the mapping is applied in
/// `compute_learning_rate` (defined elsewhere in this file).
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum LearningRateSchedule {
    /// Constant learning rate throughout training.
    #[default]
    Constant,
    /// Linear decay: lr * (1 - decay_rate * progress), presumably floored at
    /// lr * min_lr_fraction (see `compute_learning_rate` — TODO confirm).
    /// Typical values: decay_rate=0.7, min_lr_fraction=0.1. Note: no defaults
    /// are applied automatically; both fields must be supplied.
    Linear {
        decay_rate: f64,
        min_lr_fraction: f64,
    },
    /// Exponential decay: lr * exp(-decay_rate * progress).
    Exponential { decay_rate: f64 },
    /// Cosine annealing: lr * 0.5 * (1 + cos(pi * progress)).
    /// Proven effective for probabilistic models.
    Cosine,
    /// Cosine annealing with warm restarts.
    /// Restarts the schedule every `restart_period` iterations.
    CosineWarmRestarts { restart_period: u32 },
}
38
/// Line search method for finding the optimal step size (scaling) applied to
/// each boosting stage's predictions.
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum LineSearchMethod {
    /// Binary search (original NGBoost method): scale up then scale down by 2x.
    /// Fast but may miss optimal step size.
    #[default]
    Binary,
    /// Golden section search: more accurate but slightly slower.
    /// Uses the golden ratio to efficiently narrow down the optimal step size.
    /// Generally finds better step sizes with fewer function evaluations.
    GoldenSection {
        /// Maximum number of iterations. Suggested: 20. Note: the field has no
        /// automatic default — callers must supply a value (must be > 0).
        max_iters: usize,
    },
}
54
/// Golden ratio constant for golden section search.
/// Equal to (1 + sqrt(5)) / 2 ≈ 1.6180339887498949 (written out because
/// `f64::sqrt` is not usable in a `const` context).
const GOLDEN_RATIO: f64 = 1.618033988749895;
57
/// Natural Gradient Boosting model for probabilistic prediction.
///
/// `D` is the predicted distribution type, `S` the scoring rule used to
/// compute gradients, and `B` the base learner cloned and fitted to each
/// distribution parameter's gradient at every boosting stage.
pub struct NGBoost<D, S, B>
where
    D: Distribution + Scorable<S> + Clone,
    S: Score,
    B: BaseLearner + Clone,
{
    // Hyperparameters
    /// Maximum number of boosting iterations.
    pub n_estimators: u32,
    /// Base step size applied to each stage's scaled predictions.
    pub learning_rate: f64,
    /// When true, use the natural gradient (via the score's metric) instead
    /// of the ordinary gradient.
    pub natural_gradient: bool,
    /// Fraction of rows sampled without replacement per iteration, in (0, 1].
    pub minibatch_frac: f64,
    /// Fraction of feature columns sampled per iteration, in (0, 1].
    pub col_sample: f64,
    /// Enables progress logging during fitting.
    pub verbose: bool,
    /// Log every `verbose_eval` iterations when `verbose` is set.
    pub verbose_eval: u32,
    /// Numerical tolerance; validated non-negative (usage is outside this view).
    pub tol: f64,
    /// Stop after this many iterations without validation-loss improvement.
    pub early_stopping_rounds: Option<u32>,
    /// Fraction of training data held out when early stopping needs an
    /// automatic split, in [0, 1).
    pub validation_fraction: f64,
    pub adaptive_learning_rate: bool, // Enable adaptive learning rate for better convergence (deprecated, use lr_schedule)
    /// Learning rate schedule for controlling step size during training.
    pub lr_schedule: LearningRateSchedule,
    /// Tikhonov regularization parameter for stabilizing Fisher matrix inversion.
    /// Added to diagonal of Fisher Information Matrix: F + tikhonov_reg * I.
    /// Set to 0.0 to disable (default). Typical values: 1e-6 to 1e-3.
    pub tikhonov_reg: f64,
    /// Line search method for finding optimal step size.
    pub line_search_method: LineSearchMethod,

    // Base learner
    /// Prototype base learner; cloned once per distribution parameter at
    /// every boosting stage.
    base_learner: B,

    // State
    /// Fitted stages: one inner `Vec` (a learner per distribution parameter)
    /// per boosting iteration.
    pub base_models: Vec<Vec<Box<dyn TrainedBaseLearner>>>,
    /// Line-search scaling factor recorded for each stage.
    pub scalings: Vec<f64>,
    /// Marginal distribution parameters fitted to the training targets.
    pub init_params: Option<Array1<f64>>,
    /// Column indices sampled at each stage (needed to replay predictions).
    pub col_idxs: Vec<Vec<usize>>,
    /// Optional custom training-loss monitor (used in verbose logging).
    train_loss_monitor: Option<LossMonitor<D>>,
    /// Optional custom validation-loss monitor (used for early stopping).
    val_loss_monitor: Option<LossMonitor<D>>,
    /// Iteration index achieving the best validation loss during fitting.
    best_val_loss_itr: Option<usize>,
    /// Number of feature columns seen during fitting.
    n_features: Option<usize>,

    // Random number generator
    rng: ThreadRng,

    // Generics
    _dist: PhantomData<D>,
    _score: PhantomData<S>,
}
105
106impl<D, S, B> NGBoost<D, S, B>
107where
108    D: Distribution + Scorable<S> + Clone,
109    S: Score,
110    B: BaseLearner + Clone,
111{
112    pub fn new(n_estimators: u32, learning_rate: f64, base_learner: B) -> Self {
113        NGBoost {
114            n_estimators,
115            learning_rate,
116            natural_gradient: true,
117            minibatch_frac: 1.0,
118            col_sample: 1.0,
119            verbose: false,
120            verbose_eval: 100,
121            tol: 1e-4,
122            early_stopping_rounds: None,
123            validation_fraction: 0.1,
124            adaptive_learning_rate: false,
125            lr_schedule: LearningRateSchedule::Constant,
126            tikhonov_reg: 0.0,
127            line_search_method: LineSearchMethod::Binary,
128            base_learner,
129            base_models: Vec::new(),
130            scalings: Vec::new(),
131            init_params: None,
132            col_idxs: Vec::new(),
133            train_loss_monitor: None,
134            val_loss_monitor: None,
135            best_val_loss_itr: None,
136            n_features: None,
137            rng: rng(),
138            _dist: PhantomData,
139            _score: PhantomData,
140        }
141    }
142
143    pub fn with_options(
144        n_estimators: u32,
145        learning_rate: f64,
146        base_learner: B,
147        natural_gradient: bool,
148        minibatch_frac: f64,
149        col_sample: f64,
150        verbose: bool,
151        verbose_eval: u32,
152        tol: f64,
153        early_stopping_rounds: Option<u32>,
154        validation_fraction: f64,
155        adaptive_learning_rate: bool,
156    ) -> Self {
157        NGBoost {
158            n_estimators,
159            learning_rate,
160            natural_gradient,
161            minibatch_frac,
162            col_sample,
163            verbose,
164            verbose_eval,
165            tol,
166            early_stopping_rounds,
167            validation_fraction,
168            adaptive_learning_rate,
169            lr_schedule: LearningRateSchedule::Constant,
170            tikhonov_reg: 0.0,
171            line_search_method: LineSearchMethod::Binary,
172            base_learner,
173            base_models: Vec::new(),
174            scalings: Vec::new(),
175            init_params: None,
176            col_idxs: Vec::new(),
177            train_loss_monitor: None,
178            val_loss_monitor: None,
179            best_val_loss_itr: None,
180            n_features: None,
181            rng: rng(),
182            _dist: PhantomData,
183            _score: PhantomData,
184        }
185    }
186
187    /// Create NGBoost with advanced options including learning rate schedule and regularization.
188    #[allow(clippy::too_many_arguments)]
189    pub fn with_advanced_options(
190        n_estimators: u32,
191        learning_rate: f64,
192        base_learner: B,
193        natural_gradient: bool,
194        minibatch_frac: f64,
195        col_sample: f64,
196        verbose: bool,
197        verbose_eval: u32,
198        tol: f64,
199        early_stopping_rounds: Option<u32>,
200        validation_fraction: f64,
201        lr_schedule: LearningRateSchedule,
202        tikhonov_reg: f64,
203        line_search_method: LineSearchMethod,
204    ) -> Self {
205        NGBoost {
206            n_estimators,
207            learning_rate,
208            natural_gradient,
209            minibatch_frac,
210            col_sample,
211            verbose,
212            verbose_eval,
213            tol,
214            early_stopping_rounds,
215            validation_fraction,
216            adaptive_learning_rate: false,
217            lr_schedule,
218            tikhonov_reg,
219            line_search_method,
220            base_learner,
221            base_models: Vec::new(),
222            scalings: Vec::new(),
223            init_params: None,
224            col_idxs: Vec::new(),
225            train_loss_monitor: None,
226            val_loss_monitor: None,
227            best_val_loss_itr: None,
228            n_features: None,
229            rng: rng(),
230            _dist: PhantomData,
231            _score: PhantomData,
232        }
233    }
234
235    /// Set a custom training loss monitor function
236    pub fn set_train_loss_monitor(&mut self, monitor: LossMonitor<D>) {
237        self.train_loss_monitor = Some(monitor);
238    }
239
240    /// Set a custom validation loss monitor function
241    pub fn set_val_loss_monitor(&mut self, monitor: LossMonitor<D>) {
242        self.val_loss_monitor = Some(monitor);
243    }
244
    /// Fits the model to `x`/`y`, resetting any previously trained state.
    ///
    /// Equivalent to `fit_with_validation` with no validation data and no
    /// sample weights. If `early_stopping_rounds` is set, a validation split
    /// is carved out of the training data automatically.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.fit_with_validation(x, y, None, None, None, None)
    }
248
    /// Fits an NGBoost model to the data appending base models to the existing ones.
    ///
    /// NOTE: This method is similar to Python's partial_fit. The first call will be the most
    /// significant and later calls will retune the model to newer data.
    ///
    /// Unlike `fit()`, this method does NOT reset the model state, allowing incremental learning.
    ///
    /// Equivalent to `partial_fit_with_validation` with no validation data
    /// and no sample weights.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.partial_fit_with_validation(x, y, None, None, None, None)
    }
258
    /// Partial fit with validation data support.
    ///
    /// Appends new boosting stages without clearing previously fitted state
    /// (the `reset_state = false` path of `fit_internal`).
    ///
    /// * `x_val`/`y_val` — optional held-out data used for early stopping.
    /// * `sample_weight` — optional per-row training weights forwarded to the
    ///   base learners.
    /// * `val_sample_weight` — accepted for API symmetry but currently
    ///   ignored by `fit_internal`.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        // Don't reset state - this is the key difference from fit()
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, false)
    }
272
    /// Fits the model with optional validation data, resetting any previously
    /// trained state first (the `reset_state = true` path of `fit_internal`).
    ///
    /// * `x_val`/`y_val` — optional held-out data used for early stopping.
    /// * `sample_weight` — optional per-row training weights forwarded to the
    ///   base learners.
    /// * `val_sample_weight` — accepted for API symmetry but currently
    ///   ignored by `fit_internal`.
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, true)
    }
284
285    /// Validates hyperparameters before fitting.
286    /// Returns an error message if any hyperparameter is invalid.
287    fn validate_hyperparameters(&self) -> Result<(), &'static str> {
288        if self.n_estimators == 0 {
289            return Err("n_estimators must be greater than 0");
290        }
291        if self.learning_rate <= 0.0 {
292            return Err("learning_rate must be positive");
293        }
294        if self.learning_rate > 10.0 {
295            return Err("learning_rate > 10.0 is likely a mistake");
296        }
297        if self.minibatch_frac <= 0.0 || self.minibatch_frac > 1.0 {
298            return Err("minibatch_frac must be in (0, 1]");
299        }
300        if self.col_sample <= 0.0 || self.col_sample > 1.0 {
301            return Err("col_sample must be in (0, 1]");
302        }
303        if self.tol < 0.0 {
304            return Err("tol must be non-negative");
305        }
306        if self.validation_fraction < 0.0 || self.validation_fraction >= 1.0 {
307            return Err("validation_fraction must be in [0, 1)");
308        }
309        if self.tikhonov_reg < 0.0 {
310            return Err("tikhonov_reg must be non-negative");
311        }
312
313        // Validate learning rate schedule parameters
314        match self.lr_schedule {
315            LearningRateSchedule::Linear {
316                decay_rate,
317                min_lr_fraction,
318            } => {
319                if decay_rate < 0.0 || decay_rate > 1.0 {
320                    return Err("Linear schedule decay_rate must be in [0, 1]");
321                }
322                if min_lr_fraction < 0.0 || min_lr_fraction > 1.0 {
323                    return Err("Linear schedule min_lr_fraction must be in [0, 1]");
324                }
325            }
326            LearningRateSchedule::Exponential { decay_rate } => {
327                if decay_rate < 0.0 {
328                    return Err("Exponential schedule decay_rate must be non-negative");
329                }
330            }
331            LearningRateSchedule::CosineWarmRestarts { restart_period } => {
332                if restart_period == 0 {
333                    return Err("CosineWarmRestarts restart_period must be > 0");
334                }
335            }
336            _ => {}
337        }
338
339        // Validate line search parameters
340        if let LineSearchMethod::GoldenSection { max_iters } = self.line_search_method {
341            if max_iters == 0 {
342                return Err("GoldenSection max_iters must be > 0");
343            }
344        }
345
346        Ok(())
347    }
348
    /// Internal fit implementation that can optionally reset state.
    ///
    /// This is the single training entry point behind `fit`, `partial_fit`,
    /// and their `_with_validation` variants.
    ///
    /// * `x`, `y` — training features/targets (validated for shape and finiteness).
    /// * `x_val`, `y_val` — optional validation data; when absent and early
    ///   stopping is enabled, a split is carved from the training data.
    /// * `sample_weight` — optional per-row weights forwarded to base learners.
    /// * `_val_sample_weight` — currently unused (validation loss is unweighted).
    /// * `reset_state` — true for `fit` (clears previous stages); false for
    ///   `partial_fit` (new stages are appended).
    fn fit_internal(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        _val_sample_weight: Option<&Array1<f64>>,
        reset_state: bool,
    ) -> Result<(), &'static str> {
        // Validate hyperparameters first
        self.validate_hyperparameters()?;

        // Validate input dimensions with more detailed error messages
        if x.nrows() != y.len() {
            return Err("Number of samples in X and y must match");
        }
        if x.nrows() == 0 {
            return Err("Cannot fit to empty dataset");
        }
        if x.ncols() == 0 {
            return Err("Cannot fit to dataset with no features");
        }

        // Check for NaN/Inf values in input data
        if x.iter().any(|&v| !v.is_finite()) {
            return Err("Input X contains NaN or infinite values");
        }
        if y.iter().any(|&v| !v.is_finite()) {
            return Err("Input y contains NaN or infinite values");
        }

        // Reset state only if requested (fit() resets, partial_fit() doesn't)
        if reset_state {
            self.base_models.clear();
            self.scalings.clear();
            self.col_idxs.clear();
            self.best_val_loss_itr = None;
        }
        self.n_features = Some(x.ncols());

        // Handle automatic validation split if early stopping is enabled
        let (x_train, y_train, x_val_auto, y_val_auto) = if self.early_stopping_rounds.is_some()
            && x_val.is_none()
            && y_val.is_none()
            && self.validation_fraction > 0.0
            && self.validation_fraction < 1.0
        {
            // Split training data into training and validation sets
            // Shuffle indices first to match sklearn's train_test_split behavior
            let n_samples = x.nrows();
            let n_val = ((n_samples as f64) * self.validation_fraction) as usize;
            let n_train = n_samples - n_val;

            // Shuffle indices for random split (matches Python's train_test_split)
            // Fisher-Yates shuffle, iterating from the end.
            let mut indices: Vec<usize> = (0..n_samples).collect();
            for i in (1..indices.len()).rev() {
                let j = self.rng.random_range(0..=i);
                indices.swap(i, j);
            }

            let train_indices: Vec<usize> = indices[0..n_train].to_vec();
            let val_indices: Vec<usize> = indices[n_train..].to_vec();

            let x_train = x.select(ndarray::Axis(0), &train_indices);
            let y_train = y.select(ndarray::Axis(0), &train_indices);
            let x_val_auto = Some(x.select(ndarray::Axis(0), &val_indices));
            let y_val_auto = Some(y.select(ndarray::Axis(0), &val_indices));

            (x_train, y_train, x_val_auto, y_val_auto)
        } else {
            (x.to_owned(), y.to_owned(), x_val.cloned(), y_val.cloned())
        };

        // Use the automatically split or provided validation data
        let x_train = x_train;
        let y_train = y_train;
        let x_val = x_val_auto.as_ref().or(x_val);
        let y_val = y_val_auto.as_ref().or(y_val);

        // Validate validation data if provided
        if let (Some(xv), Some(yv)) = (x_val, y_val) {
            if xv.nrows() != yv.len() {
                return Err("Number of samples in validation X and y must match");
            }
            if xv.ncols() != x_train.ncols() {
                return Err("Number of features in training and validation data must match");
            }
        }

        // Fit the marginal distribution to the targets; its parameters seed
        // every row of the parameter matrix.
        self.init_params = Some(D::fit(&y_train));
        let n_params = self.init_params.as_ref().unwrap().len();
        let mut params = Array2::from_elem((x_train.nrows(), n_params), 0.0);

        // Safe unwrap with proper error handling
        let init_params = self.init_params.as_ref().unwrap();
        params
            .outer_iter_mut()
            .for_each(|mut row| row.assign(init_params));

        // Prepare validation params if validation data is provided
        let mut val_params = if let (Some(xv), Some(_yv)) = (x_val, y_val) {
            let mut v_params = Array2::from_elem((xv.nrows(), n_params), 0.0);
            v_params
                .outer_iter_mut()
                .for_each(|mut row| row.assign(init_params));
            Some(v_params)
        } else {
            None
        };

        // Early-stopping bookkeeping.
        let mut best_val_loss = f64::INFINITY;
        let mut best_iter = 0;
        let mut no_improvement_count = 0;

        for itr in 0..self.n_estimators {
            let dist = D::from_params(&params);

            // Compute gradients with optional Tikhonov regularization
            let grads = if self.natural_gradient && self.tikhonov_reg > 0.0 {
                // Use regularized natural gradient for better numerical stability
                let standard_grad = Scorable::d_score(&dist, &y_train);
                let metric = Scorable::metric(&dist);
                crate::scores::natural_gradient_regularized(
                    &standard_grad,
                    &metric,
                    self.tikhonov_reg,
                )
            } else {
                Scorable::grad(&dist, &y_train, self.natural_gradient)
            };

            // Sample data for this iteration
            let (row_idxs, col_idxs, x_sampled, y_sampled, params_sampled, weight_sampled) =
                self.sample(&x_train, &y_train, &params, sample_weight);
            self.col_idxs.push(col_idxs.clone());

            let grads_sampled = grads.select(ndarray::Axis(0), &row_idxs);

            // Fit base learners for each parameter - parallelized when feature enabled
            #[cfg(feature = "parallel")]
            let fit_results: Vec<
                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
            > = {
                // Pre-clone learners to avoid borrow issues in parallel iterator
                let learners: Vec<B> = (0..n_params).map(|_| self.base_learner.clone()).collect();
                learners
                    .into_par_iter()
                    .enumerate()
                    .map(|(j, learner)| {
                        let grad_j = grads_sampled.column(j).to_owned();
                        let fitted = learner.fit_with_weights(
                            &x_sampled,
                            &grad_j,
                            weight_sampled.as_ref(),
                        )?;
                        let preds = fitted.predict(&x_sampled);
                        Ok((fitted, preds))
                    })
                    .collect()
            };

            #[cfg(not(feature = "parallel"))]
            let fit_results: Vec<
                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
            > = (0..n_params)
                .map(|j| {
                    let grad_j = grads_sampled.column(j).to_owned();
                    let learner = self.base_learner.clone();
                    let fitted =
                        learner.fit_with_weights(&x_sampled, &grad_j, weight_sampled.as_ref())?;
                    let preds = fitted.predict(&x_sampled);
                    Ok((fitted, preds))
                })
                .collect();

            // Unpack results, propagating any errors
            let mut fitted_learners: Vec<Box<dyn TrainedBaseLearner>> =
                Vec::with_capacity(n_params);
            let mut predictions_cols: Vec<Array1<f64>> = Vec::with_capacity(n_params);
            for result in fit_results {
                let (fitted, preds) = result?;
                fitted_learners.push(fitted);
                predictions_cols.push(preds);
            }

            let predictions = to_2d_array(predictions_cols);

            // Find the per-stage scaling factor on the sampled data.
            let scale = self.line_search(
                &predictions,
                &params_sampled,
                &y_sampled,
                weight_sampled.as_ref(),
            );
            self.scalings.push(scale);
            self.base_models.push(fitted_learners);

            // Apply learning rate schedule
            let progress = itr as f64 / self.n_estimators as f64;
            let effective_learning_rate = self.compute_learning_rate(itr, progress);

            // Update parameters for ALL training samples by re-predicting on full X
            // This matches Python's behavior: after fitting base learners on minibatch,
            // we predict on the FULL training set to update all parameters
            // This is critical for correct convergence with minibatch_frac < 1.0
            let fitted_learners = self.base_models.last().unwrap();
            let full_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == x_train.ncols() {
                fitted_learners
                    .iter()
                    .map(|learner| learner.predict(&x_train))
                    .collect()
            } else {
                let x_subset = x_train.select(ndarray::Axis(1), &col_idxs);
                fitted_learners
                    .iter()
                    .map(|learner| learner.predict(&x_subset))
                    .collect()
            };
            let full_predictions = to_2d_array(full_predictions_cols);

            // Gradient-descent step in parameter space.
            params -= &(effective_learning_rate * scale * &full_predictions);

            // Update validation parameters if validation data is provided
            if let (Some(xv), Some(yv), Some(vp)) = (x_val, y_val, val_params.as_mut()) {
                // Get predictions on validation data from the fitted base learners
                // Apply column subsampling to match training
                let fitted_learners = self.base_models.last().unwrap();
                let val_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == xv.ncols() {
                    fitted_learners
                        .iter()
                        .map(|learner| learner.predict(xv))
                        .collect()
                } else {
                    let xv_subset = xv.select(ndarray::Axis(1), &col_idxs);
                    fitted_learners
                        .iter()
                        .map(|learner| learner.predict(&xv_subset))
                        .collect()
                };
                let val_predictions = to_2d_array(val_predictions_cols);
                *vp -= &(effective_learning_rate * scale * &val_predictions);

                // Calculate validation loss using monitor or default
                let val_dist = D::from_params(vp);
                let val_loss = if let Some(monitor) = &self.val_loss_monitor {
                    monitor(&val_dist, yv, None)
                } else {
                    Scorable::total_score(&val_dist, yv, None)
                };

                // Early stopping logic
                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    best_iter = itr;
                    no_improvement_count = 0;
                    self.best_val_loss_itr = Some(itr as usize);
                } else {
                    no_improvement_count += 1;
                }

                // Check if we should stop early
                if let Some(rounds) = self.early_stopping_rounds {
                    if no_improvement_count >= rounds {
                        if self.verbose {
                            println!("== Early stopping achieved.");
                            println!(
                                "== Best iteration / VAL{} (val_loss={:.4})",
                                best_iter, best_val_loss
                            );
                        }
                        break;
                    }
                }

                // Verbose logging with validation
                if self.verbose && itr % self.verbose_eval == 0 {
                    let dist = D::from_params(&params);
                    let train_loss = if let Some(monitor) = &self.train_loss_monitor {
                        monitor(&dist, &y_train, None)
                    } else {
                        Scorable::total_score(&dist, &y_train, None)
                    };
                    println!(
                        "[iter {}] train_loss={:.4} val_loss={:.4}",
                        itr, train_loss, val_loss
                    );
                }
            } else {
                // Verbose logging without validation
                if self.verbose && itr % self.verbose_eval == 0 {
                    let dist = D::from_params(&params);
                    let loss = if let Some(monitor) = &self.train_loss_monitor {
                        monitor(&dist, &y_train, None)
                    } else {
                        Scorable::total_score(&dist, &y_train, None)
                    };

                    // Calculate gradient norm for debugging
                    let grad_norm: f64 =
                        grads.iter().map(|x| x * x).sum::<f64>().sqrt() / grads.len() as f64;

                    println!(
                        "[iter {}] loss={:.4} grad_norm={:.4} scale={:.4}",
                        itr, loss, grad_norm, scale
                    );
                }
            }
        }

        Ok(())
    }
661
662    fn sample(
663        &mut self,
664        x: &Array2<f64>,
665        y: &Array1<f64>,
666        params: &Array2<f64>,
667        sample_weight: Option<&Array1<f64>>,
668    ) -> (
669        Vec<usize>,
670        Vec<usize>,
671        Array2<f64>,
672        Array1<f64>,
673        Array2<f64>,
674        Option<Array1<f64>>,
675    ) {
676        let n_samples = x.nrows();
677        let n_features = x.ncols();
678
679        // Sample rows (minibatch)
680        let sample_size = if self.minibatch_frac >= 1.0 {
681            n_samples
682        } else {
683            ((n_samples as f64) * self.minibatch_frac) as usize
684        };
685
686        // Uniform random sampling without replacement (matches Python's np.random.choice behavior)
687        // Note: Python does NOT do weighted sampling for minibatch selection,
688        // it only passes the weights to the base learner's fit method
689        let row_idxs: Vec<usize> = if sample_size == n_samples {
690            (0..n_samples).collect()
691        } else {
692            let mut indices: Vec<usize> = (0..n_samples).collect();
693            // Use Fisher-Yates shuffle for better randomness (matches numpy's algorithm)
694            for i in (1..indices.len()).rev() {
695                let j = self.rng.random_range(0..=i);
696                indices.swap(i, j);
697            }
698            indices.into_iter().take(sample_size).collect()
699        };
700
701        // Sample columns
702        let col_size = if self.col_sample >= 1.0 {
703            n_features
704        } else if self.col_sample > 0.0 {
705            ((n_features as f64) * self.col_sample) as usize
706        } else {
707            0
708        };
709
710        let col_idxs: Vec<usize> = if col_size == n_features || col_size == 0 {
711            (0..n_features).collect()
712        } else {
713            let mut indices: Vec<usize> = (0..n_features).collect();
714            indices.shuffle(&mut self.rng);
715            indices.into_iter().take(col_size).collect()
716        };
717
718        // Create sampled data with optimized single-pass selection
719        // Instead of two sequential selects (which create an intermediate array),
720        // we directly construct the result array
721        let x_sampled = if col_size == n_features {
722            // No column sampling - just select rows (single allocation)
723            x.select(ndarray::Axis(0), &row_idxs)
724        } else {
725            // Both row and column sampling - single allocation with direct indexing
726            let mut result = Array2::zeros((row_idxs.len(), col_idxs.len()));
727            for (new_row, &old_row) in row_idxs.iter().enumerate() {
728                for (new_col, &old_col) in col_idxs.iter().enumerate() {
729                    result[[new_row, new_col]] = x[[old_row, old_col]];
730                }
731            }
732            result
733        };
734        let y_sampled = y.select(ndarray::Axis(0), &row_idxs);
735        let params_sampled = params.select(ndarray::Axis(0), &row_idxs);
736
737        // Handle sample weights
738        let sample_weights_sampled =
739            sample_weight.map(|weights| weights.select(ndarray::Axis(0), &row_idxs));
740
741        (
742            row_idxs,
743            col_idxs,
744            x_sampled,
745            y_sampled,
746            params_sampled,
747            sample_weights_sampled,
748        )
749    }
750
    /// Returns accumulated distribution parameters for `x` using all fitted
    /// stages (delegates to `get_params_at` with `max_iter = None`).
    fn get_params(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params_at(x, None)
    }
754
    /// Predicted distribution parameters for `x`, accumulated over at most
    /// `max_iter` boosting iterations (all fitted iterations when `None`).
    ///
    /// Returns a `(n_rows, n_params)` array: each row starts at `init_params`
    /// and is updated by subtracting each iteration's scaled base-learner output.
    ///
    /// # Panics
    /// Panics if the model has not been fitted (`init_params` is `None`).
    fn get_params_at(&self, x: &Array2<f64>, max_iter: Option<usize>) -> Array2<f64> {
        // Empty input: nothing to predict.
        if x.nrows() == 0 {
            return Array2::zeros((0, 0));
        }

        let init_params = self
            .init_params
            .as_ref()
            .expect("Model has not been fitted. Call fit() before predict().");
        let n_params = init_params.len();
        // Every row begins at the fitted initial (marginal) parameters.
        let mut params = Array2::from_elem((x.nrows(), n_params), 0.0);
        params
            .outer_iter_mut()
            .for_each(|mut row| row.assign(init_params));

        // Clamp the requested iteration count to what was actually trained.
        let n_iters = max_iter
            .unwrap_or(self.base_models.len())
            .min(self.base_models.len());

        for (i, (learners, col_idx)) in self
            .base_models
            .iter()
            .zip(self.col_idxs.iter())
            .enumerate()
            .take(n_iters)
        {
            let scale = self.scalings[i];

            // Apply column subsampling during prediction to match training
            // This is critical when col_sample < 1.0
            let predictions_cols: Vec<Array1<f64>> = if col_idx.len() == x.ncols() {
                learners.iter().map(|learner| learner.predict(x)).collect()
            } else {
                let x_subset = x.select(ndarray::Axis(1), col_idx);
                learners
                    .iter()
                    .map(|learner| learner.predict(&x_subset))
                    .collect()
            };

            // One column per distribution parameter.
            let predictions = to_2d_array(predictions_cols);

            // NOTE(review): this uses the constant base `learning_rate`. If a
            // non-constant `lr_schedule` or `adaptive_learning_rate` was active
            // during training, confirm the per-iteration rate is already folded
            // into `scalings` — otherwise prediction would not match training.
            params -= &(self.learning_rate * scale * &predictions);
        }
        params
    }
801
    /// Get the predicted distribution parameters (like Python's pred_param)
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params(x)
    }

    /// Get the predicted distribution parameters up to a specific iteration
    pub fn pred_param_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        self.get_params_at(x, Some(max_iter))
    }

    /// Get the predicted distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> D {
        let params = self.get_params(x);
        D::from_params(&params)
    }

    /// Get the predicted distribution up to a specific iteration
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> D {
        let params = self.get_params_at(x, Some(max_iter));
        D::from_params(&params)
    }

    /// Point predictions derived from the predicted distribution (via `D::predict`).
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.pred_dist(x).predict()
    }

    /// Get predictions up to a specific iteration
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.pred_dist_at(x, max_iter).predict()
    }

    /// Returns an iterator over staged predictions (predictions at each boosting iteration)
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        (1..=self.base_models.len()).map(move |i| self.predict_at(x, i))
    }

    /// Returns an iterator over staged distribution predictions
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = D> + 'a {
        (1..=self.base_models.len()).map(move |i| self.pred_dist_at(x, i))
    }

    /// Compute the average score (loss) on the given data
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        let dist = self.pred_dist(x);
        Scorable::total_score(&dist, y, None)
    }

    /// Get number of features the model was trained on
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }
855
856    /// Compute the effective learning rate for the given iteration using the configured schedule.
857    fn compute_learning_rate(&self, iteration: u32, progress: f64) -> f64 {
858        // Legacy adaptive_learning_rate takes precedence for backward compatibility
859        if self.adaptive_learning_rate {
860            return self.learning_rate * (1.0 - 0.7 * progress).max(0.1);
861        }
862
863        match self.lr_schedule {
864            LearningRateSchedule::Constant => self.learning_rate,
865            LearningRateSchedule::Linear {
866                decay_rate,
867                min_lr_fraction,
868            } => self.learning_rate * (1.0 - decay_rate * progress).max(min_lr_fraction),
869            LearningRateSchedule::Exponential { decay_rate } => {
870                self.learning_rate * (-decay_rate * progress).exp()
871            }
872            LearningRateSchedule::Cosine => {
873                self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * progress).cos())
874            }
875            LearningRateSchedule::CosineWarmRestarts { restart_period } => {
876                let period_progress = (iteration % restart_period) as f64 / restart_period as f64;
877                self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * period_progress).cos())
878            }
879        }
880    }
881
882    /// Compute feature importances based on how often each feature is used in splits.
883    /// Returns a 2D array of shape (n_params, n_features) where each row contains
884    /// the normalized feature importances for that distribution parameter.
885    /// Returns None if the model hasn't been trained or has no features.
886    pub fn feature_importances(&self) -> Option<Array2<f64>> {
887        let n_features = self.n_features?;
888        if self.base_models.is_empty() || n_features == 0 {
889            return None;
890        }
891
892        let n_params = self.init_params.as_ref()?.len();
893        let mut importances = Array2::zeros((n_params, n_features));
894
895        // Aggregate feature usage across all iterations, weighted by scaling factor
896        for (iter_idx, learners) in self.base_models.iter().enumerate() {
897            let scale = self.scalings[iter_idx].abs();
898
899            for (param_idx, learner) in learners.iter().enumerate() {
900                if let Some(feature_idx) = learner.split_feature() {
901                    if feature_idx < n_features {
902                        importances[[param_idx, feature_idx]] += scale;
903                    }
904                }
905            }
906        }
907
908        // Normalize each parameter's importances to sum to 1
909        for mut row in importances.rows_mut() {
910            let sum: f64 = row.sum();
911            if sum > 0.0 {
912                row.mapv_inplace(|v| v / sum);
913            }
914        }
915
916        Some(importances)
917    }
918
919    /// Calibrate uncertainty estimates using isotonic regression on validation data
920    /// This improves the quality of probabilistic predictions by adjusting the variance estimates
921    pub fn calibrate_uncertainty(
922        &mut self,
923        x_val: &Array2<f64>,
924        y_val: &Array1<f64>,
925    ) -> Result<(), &'static str> {
926        if self.base_models.is_empty() {
927            return Err("Model must be trained before calibration");
928        }
929
930        // Get predictions on validation data
931        let params = self.pred_param(x_val);
932        let dist = D::from_params(&params);
933
934        // Calculate predictions and errors
935        let predictions = dist.predict();
936        let errors = y_val - &predictions;
937
938        // Calculate empirical variance
939        let empirical_var = errors.mapv(|e| e * e).mean().unwrap_or(1.0);
940
941        // For normal distribution (2 parameters), adjust the scale parameter
942        if let Some(init_params) = self.init_params.as_mut() {
943            if init_params.len() >= 2 {
944                // The second parameter is log(scale), so we adjust it based on empirical variance
945                let current_var = (-init_params[1]).exp(); // exp(2*log(scale)) = scale^2
946                let target_var = empirical_var;
947                let calibration_factor = (target_var / current_var).sqrt();
948                init_params[1] += calibration_factor.ln();
949            }
950        }
951
952        Ok(())
953    }
954
955    /// Compute aggregated feature importances across all distribution parameters.
956    /// Returns a 1D array of length n_features with normalized importances.
957    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
958        let importances = self.feature_importances()?;
959        let mut aggregated = importances.sum_axis(ndarray::Axis(0));
960
961        let sum: f64 = aggregated.sum();
962        if sum > 0.0 {
963            aggregated.mapv_inplace(|v| v / sum);
964        }
965
966        Some(aggregated)
967    }
968
    /// Dispatch to the configured line-search method to find a step size for
    /// the residual direction `resids` starting from parameters `start`.
    fn line_search(
        &self,
        resids: &Array2<f64>,
        start: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> f64 {
        match self.line_search_method {
            LineSearchMethod::Binary => self.line_search_binary(resids, start, y, sample_weight),
            LineSearchMethod::GoldenSection { max_iters } => {
                self.line_search_golden_section(resids, start, y, sample_weight, max_iters)
            }
        }
    }
983
984    /// Binary line search (original NGBoost method).
985    fn line_search_binary(
986        &self,
987        resids: &Array2<f64>,
988        start: &Array2<f64>,
989        y: &Array1<f64>,
990        sample_weight: Option<&Array1<f64>>,
991    ) -> f64 {
992        let mut scale = 1.0;
993        let initial_score = Scorable::total_score(&D::from_params(start), y, sample_weight);
994
995        // Scale up phase: try to find a larger step that still reduces loss
996        loop {
997            if scale > 256.0 {
998                break;
999            }
1000            let scaled_resids = resids * (scale * 2.0);
1001            let next_params = start - &scaled_resids;
1002            let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
1003            if score >= initial_score || !score.is_finite() {
1004                break;
1005            }
1006            scale *= 2.0;
1007        }
1008
1009        // Scale down phase: find a step that actually reduces loss
1010        loop {
1011            let scaled_resids = resids * scale;
1012            let norm: f64 = scaled_resids
1013                .rows()
1014                .into_iter()
1015                .map(|row| row.iter().map(|x| x * x).sum::<f64>().sqrt())
1016                .sum::<f64>()
1017                / scaled_resids.nrows() as f64;
1018            if norm < self.tol {
1019                break;
1020            }
1021
1022            let next_params = start - &scaled_resids;
1023            let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
1024            if score < initial_score && score.is_finite() {
1025                break;
1026            }
1027            scale *= 0.5;
1028
1029            if scale < 1e-10 {
1030                break;
1031            }
1032        }
1033
1034        scale
1035    }
1036
1037    /// Golden section line search for more accurate step size.
1038    /// Uses the golden ratio to efficiently narrow down the optimal step size.
1039    fn line_search_golden_section(
1040        &self,
1041        resids: &Array2<f64>,
1042        start: &Array2<f64>,
1043        y: &Array1<f64>,
1044        sample_weight: Option<&Array1<f64>>,
1045        max_iters: usize,
1046    ) -> f64 {
1047        // Helper to compute score at a given scale
1048        let compute_score = |scale: f64| -> f64 {
1049            let scaled_resids = resids * scale;
1050            let next_params = start - &scaled_resids;
1051            Scorable::total_score(&D::from_params(&next_params), y, sample_weight)
1052        };
1053
1054        let initial_score = compute_score(0.0);
1055
1056        // First, find a reasonable upper bound by scaling up
1057        let mut upper = 1.0;
1058        while upper < 256.0 {
1059            let score = compute_score(upper * 2.0);
1060            if score >= initial_score || !score.is_finite() {
1061                break;
1062            }
1063            upper *= 2.0;
1064        }
1065
1066        // Golden section search between 0 and upper
1067        let mut a = 0.0;
1068        let mut b = upper;
1069        let inv_phi = 1.0 / GOLDEN_RATIO;
1070        let _inv_phi2 = 1.0 / (GOLDEN_RATIO * GOLDEN_RATIO); // Available for Brent's method extension
1071
1072        // Initial interior points
1073        let mut c = b - (b - a) * inv_phi;
1074        let mut d = a + (b - a) * inv_phi;
1075        let mut fc = compute_score(c);
1076        let mut fd = compute_score(d);
1077
1078        for _ in 0..max_iters {
1079            if (b - a).abs() < self.tol {
1080                break;
1081            }
1082
1083            if fc < fd {
1084                // Minimum is in [a, d]
1085                b = d;
1086                d = c;
1087                fd = fc;
1088                c = b - (b - a) * inv_phi;
1089                fc = compute_score(c);
1090            } else {
1091                // Minimum is in [c, b]
1092                a = c;
1093                c = d;
1094                fc = fd;
1095                d = a + (b - a) * inv_phi;
1096                fd = compute_score(d);
1097            }
1098        }
1099
1100        // Return the midpoint of the final interval
1101        let scale = (a + b) / 2.0;
1102
1103        // Verify the scale actually reduces loss, otherwise fall back
1104        let final_score = compute_score(scale);
1105        if final_score < initial_score && final_score.is_finite() {
1106            scale
1107        } else {
1108            // Fall back to a small step
1109            1.0
1110        }
1111    }
1112
1113    /// Serialize the model to a platform-independent format
1114    pub fn serialize(&self) -> Result<SerializedNGBoost, Box<dyn std::error::Error>> {
1115        // Serialize base models
1116        let serialized_base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>> = self
1117            .base_models
1118            .iter()
1119            .map(|learners| {
1120                learners
1121                    .iter()
1122                    .filter_map(|learner| learner.to_serializable())
1123                    .collect()
1124            })
1125            .collect();
1126
1127        Ok(SerializedNGBoost {
1128            n_estimators: self.n_estimators,
1129            learning_rate: self.learning_rate,
1130            natural_gradient: self.natural_gradient,
1131            minibatch_frac: self.minibatch_frac,
1132            col_sample: self.col_sample,
1133            verbose: self.verbose,
1134            verbose_eval: self.verbose_eval,
1135            tol: self.tol,
1136            early_stopping_rounds: self.early_stopping_rounds,
1137            validation_fraction: self.validation_fraction,
1138            init_params: self.init_params.as_ref().map(|p| p.to_vec()),
1139            scalings: self.scalings.clone(),
1140            col_idxs: self.col_idxs.clone(),
1141            best_val_loss_itr: self.best_val_loss_itr,
1142            base_models: serialized_base_models,
1143            lr_schedule: self.lr_schedule,
1144            tikhonov_reg: self.tikhonov_reg,
1145            line_search_method: self.line_search_method,
1146            n_features: self.n_features,
1147        })
1148    }
1149
    /// Deserialize the model from a platform-independent format.
    ///
    /// `base_learner` supplies a fresh learner template (its configuration is
    /// not stored in the serialized form).
    pub fn deserialize(
        serialized: SerializedNGBoost,
        base_learner: B,
    ) -> Result<Self, Box<dyn std::error::Error>>
    where
        D: Distribution + Scorable<S> + Clone,
        S: Score,
        B: BaseLearner + Clone,
    {
        // Rebuild an untrained model with the stored hyperparameters.
        let mut model = Self::with_options(
            serialized.n_estimators,
            serialized.learning_rate,
            base_learner,
            serialized.natural_gradient,
            serialized.minibatch_frac,
            serialized.col_sample,
            serialized.verbose,
            serialized.verbose_eval,
            serialized.tol,
            serialized.early_stopping_rounds,
            serialized.validation_fraction,
            false, // Default adaptive_learning_rate to false for backward compatibility
        );

        // Restore trained state
        if let Some(init_params) = serialized.init_params {
            model.init_params = Some(Array1::from(init_params));
        }
        model.scalings = serialized.scalings;
        model.col_idxs = serialized.col_idxs;
        model.best_val_loss_itr = serialized.best_val_loss_itr;

        // Restore advanced options (added in v0.3)
        model.lr_schedule = serialized.lr_schedule;
        model.tikhonov_reg = serialized.tikhonov_reg;
        model.line_search_method = serialized.line_search_method;
        model.n_features = serialized.n_features;

        // Restore base models (one Vec of learners per boosting iteration)
        model.base_models = serialized
            .base_models
            .into_iter()
            .map(|learners| learners.into_iter().map(|l| l.to_trait_object()).collect())
            .collect();

        Ok(model)
    }
1198}
1199
/// Serialized model data structure.
///
/// NOTE: bincode encodes fields in declaration order, so the field order below
/// is part of the on-disk format — do not reorder existing fields; append new
/// ones with `#[serde(default)]` as done for the v0.3 additions.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct SerializedNGBoost {
    pub n_estimators: u32,
    pub learning_rate: f64,
    pub natural_gradient: bool,
    pub minibatch_frac: f64,
    pub col_sample: f64,
    pub verbose: bool,
    pub verbose_eval: u32,
    pub tol: f64,
    pub early_stopping_rounds: Option<u32>,
    pub validation_fraction: f64,
    pub init_params: Option<Vec<f64>>,
    pub scalings: Vec<f64>,
    pub col_idxs: Vec<Vec<usize>>,
    pub best_val_loss_itr: Option<usize>,
    /// Serialized base models - each inner Vec contains learners for each parameter
    pub base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>>,
    /// Learning rate schedule (added in v0.3)
    #[serde(default)]
    pub lr_schedule: LearningRateSchedule,
    /// Tikhonov regularization parameter (added in v0.3)
    #[serde(default)]
    pub tikhonov_reg: f64,
    /// Line search method (added in v0.3)
    #[serde(default)]
    pub line_search_method: LineSearchMethod,
    /// Number of features the model was trained on (added in v0.3)
    #[serde(default)]
    pub n_features: Option<usize>,
}
1232
1233fn to_2d_array(cols: Vec<Array1<f64>>) -> Array2<f64> {
1234    if cols.is_empty() {
1235        return Array2::zeros((0, 0));
1236    }
1237    let nrows = cols[0].len();
1238    let ncols = cols.len();
1239    let mut arr = Array2::zeros((nrows, ncols));
1240    for (j, col) in cols.iter().enumerate() {
1241        arr.column_mut(j).assign(col);
1242    }
1243    arr
1244}
1245
1246// High-level API
/// High-level regression API: NGBoost with a `Normal` distribution, log score,
/// and decision-tree base learners.
pub struct NGBRegressor {
    model: NGBoost<Normal, LogScore, DecisionTreeLearner>,
}
1250
/// High-level binary-classification API: NGBoost with a `Bernoulli`
/// distribution, log score, and decision-tree base learners.
pub struct NGBClassifier {
    model: NGBoost<Bernoulli, LogScore, DecisionTreeLearner>,
}
1254
// Thin delegation facade over the generic `NGBoost` model; all methods forward
// to the inner model with the regressor's concrete type parameters.
impl NGBRegressor {
    /// Create a regressor with default options and the default tree learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Fit the model to the training data.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with an optional explicit validation set (both `x_val` and `y_val`
    /// should be provided together).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Fits an NGBoost model to the data appending base models to the existing ones.
    ///
    /// NOTE: This method is similar to Python's partial_fit. The first call will be the most
    /// significant and later calls will retune the model to newer data.
    ///
    /// Unlike `fit()`, this method does NOT reset the model state, allowing incremental learning.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Partial fit with validation data support.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Point predictions derived from the predicted distribution.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Get predictions up to a specific iteration
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Returns an iterator over staged predictions
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Get the predicted Normal distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Normal {
        self.model.pred_dist(x)
    }

    /// Get the predicted distribution up to a specific iteration
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Normal {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Returns an iterator over staged distribution predictions
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = Normal> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Get the predicted distribution parameters
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Compute the average score (loss) on the given data
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }

    /// Set a custom training loss monitor function
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Set a custom validation loss monitor function
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Enhanced constructor with all options
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: u32,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }

    /// Enhanced constructor with all options (backward compatible version without adaptive_learning_rate)
    pub fn with_options_compat(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: u32,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
    ) -> Self {
        Self::with_options(
            n_estimators,
            learning_rate,
            natural_gradient,
            minibatch_frac,
            col_sample,
            verbose,
            verbose_eval,
            tol,
            early_stopping_rounds,
            validation_fraction,
            false, // Default adaptive_learning_rate to false
        )
    }

    /// Enable adaptive learning rate for better convergence in probabilistic forecasting
    pub fn set_adaptive_learning_rate(&mut self, enabled: bool) {
        self.model.adaptive_learning_rate = enabled;
    }

    /// Calibrate uncertainty estimates using validation data
    /// This improves the quality of probabilistic predictions
    pub fn calibrate_uncertainty(
        &mut self,
        x_val: &Array2<f64>,
        y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        self.model.calibrate_uncertainty(x_val, y_val)
    }

    /// Get the number of estimators (boosting iterations)
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Get the learning rate
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Get whether natural gradient is used
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Get the minibatch fraction
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Get the column sampling fraction
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Get the best validation iteration
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Get early stopping rounds
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Get validation fraction
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Get number of features the model was trained on
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Compute feature importances per distribution parameter.
    /// Returns a 2D array of shape (n_params, n_features).
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Compute aggregated feature importances across all distribution parameters.
    /// Returns a 1D array of length n_features with normalized importances.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }

    /// Save model to file using bincode serialization
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        let serialized = self.model.serialize()?;
        let encoded = bincode::serialize(&serialized)?;
        std::fs::write(path, encoded)?;
        Ok(())
    }

    /// Load model from file using bincode deserialization
    pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoded = std::fs::read(path)?;
        let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
        let model = NGBoost::<Normal, LogScore, DecisionTreeLearner>::deserialize(
            serialized,
            default_tree_learner(),
        )?;
        Ok(Self { model })
    }
}
1507
1508impl NGBClassifier {
    /// Create a classifier with default options and the default tree learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Enhanced constructor with all options.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: u32,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }
1545
    /// Set a custom training loss monitor function
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Set a custom validation loss monitor function
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Fit the model to the training data.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with an optional explicit validation set (both `x_val` and `y_val`
    /// should be provided together).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Fits an NGBoost model to the data appending base models to the existing ones.
    ///
    /// NOTE: This method is similar to Python's partial_fit. The first call will be the most
    /// significant and later calls will retune the model to newer data.
    ///
    /// Unlike `fit()`, this method does NOT reset the model state, allowing incremental learning.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Partial fit with validation data support.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }
1598
1599    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
1600        self.model.predict(x)
1601    }
1602
1603    /// Get predictions up to a specific iteration
1604    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
1605        self.model.predict_at(x, max_iter)
1606    }
1607
1608    /// Returns an iterator over staged predictions
1609    pub fn staged_predict<'a>(
1610        &'a self,
1611        x: &'a Array2<f64>,
1612    ) -> impl Iterator<Item = Array1<f64>> + 'a {
1613        self.model.staged_predict(x)
1614    }
1615
1616    pub fn predict_proba(&self, x: &Array2<f64>) -> Array2<f64> {
1617        let dist = self.model.pred_dist(x);
1618        dist.class_probs()
1619    }
1620
1621    /// Get class probabilities up to a specific iteration
1622    pub fn predict_proba_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
1623        let dist = self.model.pred_dist_at(x, max_iter);
1624        dist.class_probs()
1625    }
1626
1627    /// Returns an iterator over staged probability predictions
1628    pub fn staged_predict_proba<'a>(
1629        &'a self,
1630        x: &'a Array2<f64>,
1631    ) -> impl Iterator<Item = Array2<f64>> + 'a {
1632        (1..=self.model.base_models.len()).map(move |i| self.predict_proba_at(x, i))
1633    }
1634
    /// Predicted Bernoulli distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Bernoulli {
        self.model.pred_dist(x)
    }
1638
    /// Predicted distribution using only the first `max_iter` boosting iterations.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Bernoulli {
        self.model.pred_dist_at(x, max_iter)
    }
1643
    /// Returns an iterator over staged distribution predictions: one Bernoulli
    /// per boosting iteration, in order.
    pub fn staged_pred_dist<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Bernoulli> + 'a {
        self.model.staged_pred_dist(x)
    }
1651
    /// Get the predicted distribution parameters for each row of `x`.
    // NOTE(review): whether these are in the internal (untransformed) parameter
    // space or the natural parameter space is determined by the model's
    // `pred_param` — confirm there.
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }
1656
    /// Compute the average score (loss) on the given data; lower is better.
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }
1661
    /// Configured number of boosting iterations (estimators).
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }
1666
    /// Base learning rate used for each boosting step.
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }
1671
    /// Whether the natural gradient (Fisher-preconditioned) is used instead of
    /// the ordinary gradient.
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }
1676
    /// Fraction of training rows sampled per boosting iteration.
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }
1681
    /// Fraction of feature columns sampled per boosting iteration.
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }
1686
    /// Index of the boosting iteration with the best validation loss
    /// (`None` presumably means no validation tracking occurred — confirm).
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }
1691
    /// Number of rounds without validation improvement before training stops
    /// early; `None` disables early stopping.
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }
1696
    /// Fraction of the training data held out for validation.
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }
1701
    /// Number of features the model was trained on
    /// (`None` presumably means the model has not been fitted yet — confirm).
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }
1706
    /// Compute feature importances per distribution parameter.
    /// Returns a 2D array of shape (n_params, n_features);
    /// `None` presumably when the model is unfitted — confirm.
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }
1712
    /// Compute aggregated feature importances across all distribution parameters.
    /// Returns a 1D array of length n_features with normalized importances;
    /// `None` presumably when the model is unfitted — confirm.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }
1718
1719    /// Save model to file using bincode serialization
1720    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
1721        let serialized = self.model.serialize()?;
1722        let encoded = bincode::serialize(&serialized)?;
1723        std::fs::write(path, encoded)?;
1724        Ok(())
1725    }
1726
1727    /// Load model from file using bincode deserialization
1728    pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
1729        let encoded = std::fs::read(path)?;
1730        let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
1731        let model = NGBoost::<Bernoulli, LogScore, DecisionTreeLearner>::deserialize(
1732            serialized,
1733            default_tree_learner(),
1734        )?;
1735        Ok(Self { model })
1736    }
1737}