ngboost_rs/
ngboost.rs

/// Type alias for user-supplied loss monitor callbacks.
///
/// A monitor receives the current fitted distribution, the target values,
/// and optional per-sample weights, and returns a scalar loss to record in
/// place of the default score. `Send + Sync` allows monitors to cross
/// thread boundaries (relevant when the `parallel` feature is enabled).
pub type LossMonitor<D> = Box<dyn Fn(&D, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync>;
3
4use crate::dist::categorical::{Bernoulli, Categorical};
5use crate::dist::normal::Normal;
6use crate::dist::{ClassificationDistn, Distribution};
7use crate::learners::{default_tree_learner, BaseLearner, DecisionTreeLearner, TrainedBaseLearner};
8use crate::scores::{LogScore, Scorable, Score};
9use ndarray::{Array1, Array2};
10use rand::prelude::*;
11use rand::rngs::StdRng;
12use rand::SeedableRng;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
/// Learning rate schedule for controlling step size during training.
///
/// `progress` below is the fraction of boosting iterations completed so far
/// (`itr / n_estimators`, in `[0, 1)`).
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum LearningRateSchedule {
    /// Constant learning rate throughout training.
    #[default]
    Constant,
    /// Linear decay: lr * (1 - decay_rate * progress), clamped to a floor
    /// given by `min_lr_fraction`.
    /// Default: decay_rate=0.7, min_lr_fraction=0.1
    Linear {
        /// Decay slope; must be in [0, 1] (enforced by validation).
        decay_rate: f64,
        /// Lower bound for the effective rate as a fraction of `lr`;
        /// must be in [0, 1] (enforced by validation).
        min_lr_fraction: f64,
    },
    /// Exponential decay: lr * exp(-decay_rate * progress).
    /// `decay_rate` must be non-negative (enforced by validation).
    Exponential { decay_rate: f64 },
    /// Cosine annealing: lr * 0.5 * (1 + cos(pi * progress)).
    /// Proven effective for probabilistic models.
    Cosine,
    /// Cosine annealing with warm restarts.
    /// Restarts the schedule every `restart_period` iterations;
    /// `restart_period` must be > 0 (enforced by validation).
    CosineWarmRestarts { restart_period: u32 },
}
39
/// Line search method for finding optimal step size.
#[derive(Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum LineSearchMethod {
    /// Binary search (original NGBoost method): scale up then scale down by 2x.
    /// Fast but may miss optimal step size.
    #[default]
    Binary,
    /// Golden section search: more accurate but slightly slower.
    /// Uses the golden ratio to efficiently narrow down the optimal step size.
    /// Generally finds better step sizes with fewer function evaluations.
    GoldenSection {
        /// Maximum number of iterations (default: 20).
        /// Must be > 0 (enforced by hyperparameter validation).
        max_iters: usize,
    },
}
55
/// Golden ratio constant, (1 + sqrt(5)) / 2, used by the golden-section
/// line search to split its search interval.
const GOLDEN_RATIO: f64 = 1.618033988749895;
58
/// Training history result containing loss values at each iteration.
#[derive(Debug, Clone, Default)]
pub struct EvalsResult {
    /// Training loss at each iteration.
    pub train: Vec<f64>,
    /// Validation loss at each iteration; stays empty when no validation
    /// data is available (explicit or via the automatic split).
    pub val: Vec<f64>,
}
67
/// Hyperparameters for NGBoost models.
/// Used for `get_params()` and `set_params()` methods (sklearn-style interface).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NGBoostParams {
    /// Number of boosting iterations.
    pub n_estimators: u32,
    /// Learning rate (step size shrinkage).
    pub learning_rate: f64,
    /// Whether to use natural gradient (recommended).
    pub natural_gradient: bool,
    /// Fraction of samples to use per iteration (1.0 = all samples).
    pub minibatch_frac: f64,
    /// Fraction of features to use per iteration (1.0 = all features).
    pub col_sample: f64,
    /// Whether to print training progress.
    pub verbose: bool,
    /// Verbose evaluation interval: >= 1.0 means every `verbose_eval`
    /// iterations; in (0, 1) a fraction of `n_estimators`; <= 0.0 disables
    /// verbose output entirely.
    pub verbose_eval: f64,
    /// Tolerance for early stopping based on loss improvement.
    pub tol: f64,
    /// Number of rounds without improvement before stopping.
    pub early_stopping_rounds: Option<u32>,
    /// Fraction of training data to use for validation when no explicit validation set provided.
    pub validation_fraction: f64,
    /// Random seed for reproducibility. `None` seeds from OS entropy.
    pub random_state: Option<u64>,
    /// Learning rate schedule.
    pub lr_schedule: LearningRateSchedule,
    /// Tikhonov regularization for Fisher matrix (0.0 disables).
    pub tikhonov_reg: f64,
    /// Line search method.
    pub line_search_method: LineSearchMethod,
}
101
102impl Default for NGBoostParams {
103    fn default() -> Self {
104        Self {
105            n_estimators: 500,
106            learning_rate: 0.01,
107            natural_gradient: true,
108            minibatch_frac: 1.0,
109            col_sample: 1.0,
110            verbose: false,
111            verbose_eval: 1.0,
112            tol: 1e-4,
113            early_stopping_rounds: None,
114            validation_fraction: 0.1,
115            random_state: None,
116            lr_schedule: LearningRateSchedule::Constant,
117            tikhonov_reg: 0.0,
118            line_search_method: LineSearchMethod::Binary,
119        }
120    }
121}
122
/// Natural Gradient Boosting model for probabilistic prediction.
///
/// `D` is the predicted distribution type, `S` the scoring rule it is
/// trained against, and `B` the base learner fitted to the gradients at
/// each boosting iteration.
pub struct NGBoost<D, S, B>
where
    D: Distribution + Scorable<S> + Clone,
    S: Score,
    B: BaseLearner + Clone,
{
    // Hyperparameters
    /// Number of boosting iterations.
    pub n_estimators: u32,
    /// Learning rate (step-size shrinkage) applied to each boosting step.
    pub learning_rate: f64,
    /// Whether to use the natural (Fisher-preconditioned) gradient.
    pub natural_gradient: bool,
    /// Fraction of rows sampled per iteration (1.0 = all samples).
    pub minibatch_frac: f64,
    /// Fraction of feature columns sampled per iteration (1.0 = all features).
    pub col_sample: f64,
    /// Whether to print training progress.
    pub verbose: bool,
    /// Interval for verbose output during training.
    /// - If >= 1.0: print every `verbose_eval` iterations (e.g., 100 means every 100 iterations)
    /// - If < 1.0 and > 0.0: print every `verbose_eval * n_estimators` iterations (e.g., 0.1 means every 10%)
    /// - If <= 0.0: no verbose output regardless of `verbose` setting
    pub verbose_eval: f64,
    /// Tolerance on loss improvement used for early stopping.
    pub tol: f64,
    /// Rounds without validation improvement before stopping early.
    pub early_stopping_rounds: Option<u32>,
    /// Fraction of training data split off for validation when early
    /// stopping is enabled and no explicit validation set is provided.
    pub validation_fraction: f64,
    pub adaptive_learning_rate: bool, // Enable adaptive learning rate for better convergence (deprecated, use lr_schedule)
    /// Learning rate schedule for controlling step size during training.
    pub lr_schedule: LearningRateSchedule,
    /// Tikhonov regularization parameter for stabilizing Fisher matrix inversion.
    /// Added to diagonal of Fisher Information Matrix: F + tikhonov_reg * I.
    /// Set to 0.0 to disable (default). Typical values: 1e-6 to 1e-3.
    pub tikhonov_reg: f64,
    /// Line search method for finding optimal step size.
    pub line_search_method: LineSearchMethod,

    // Base learner
    // Prototype learner, cloned once per distribution parameter at every
    // boosting iteration.
    base_learner: B,

    // State
    /// Fitted base learners: one inner `Vec` (one learner per distribution
    /// parameter) for each boosting iteration.
    pub base_models: Vec<Vec<Box<dyn TrainedBaseLearner>>>,
    /// Line-search scaling factor chosen at each iteration.
    pub scalings: Vec<f64>,
    /// Initial distribution parameters fitted on the training targets.
    pub init_params: Option<Array1<f64>>,
    /// Feature-column indices used at each iteration (column subsampling).
    pub col_idxs: Vec<Vec<usize>>,
    // Optional user-supplied loss callbacks (see `set_*_loss_monitor`).
    train_loss_monitor: Option<LossMonitor<D>>,
    val_loss_monitor: Option<LossMonitor<D>>,
    // Iteration with the best validation loss (tracked during fitting).
    best_val_loss_itr: Option<usize>,
    // Number of feature columns seen during `fit`.
    n_features: Option<usize>,
    /// Training history containing loss values at each iteration.
    pub evals_result: EvalsResult,

    // Random number generator (seeded for reproducibility)
    rng: StdRng,
    /// Optional random seed for reproducibility. If None, uses entropy from the OS.
    random_state: Option<u64>,

    // Generics
    _dist: PhantomData<D>,
    _score: PhantomData<S>,
}
178
179impl<D, S, B> NGBoost<D, S, B>
180where
181    D: Distribution + Scorable<S> + Clone,
182    S: Score,
183    B: BaseLearner + Clone,
184{
185    pub fn new(n_estimators: u32, learning_rate: f64, base_learner: B) -> Self {
186        NGBoost {
187            n_estimators,
188            learning_rate,
189            natural_gradient: true,
190            minibatch_frac: 1.0,
191            col_sample: 1.0,
192            verbose: false,
193            verbose_eval: 100.0,
194            tol: 1e-4,
195            early_stopping_rounds: None,
196            validation_fraction: 0.1,
197            adaptive_learning_rate: false,
198            lr_schedule: LearningRateSchedule::Constant,
199            tikhonov_reg: 0.0,
200            line_search_method: LineSearchMethod::Binary,
201            base_learner,
202            base_models: Vec::new(),
203            scalings: Vec::new(),
204            init_params: None,
205            col_idxs: Vec::new(),
206            train_loss_monitor: None,
207            val_loss_monitor: None,
208            best_val_loss_itr: None,
209            n_features: None,
210            evals_result: EvalsResult::default(),
211            rng: StdRng::from_rng(&mut rand::rng()),
212            random_state: None,
213            _dist: PhantomData,
214            _score: PhantomData,
215        }
216    }
217
218    /// Create a new NGBoost model with a specific random seed for reproducibility.
219    /// This is equivalent to Python's `random_state` parameter.
220    pub fn with_seed(n_estimators: u32, learning_rate: f64, base_learner: B, seed: u64) -> Self {
221        NGBoost {
222            n_estimators,
223            learning_rate,
224            natural_gradient: true,
225            minibatch_frac: 1.0,
226            col_sample: 1.0,
227            verbose: false,
228            verbose_eval: 100.0,
229            tol: 1e-4,
230            early_stopping_rounds: None,
231            validation_fraction: 0.1,
232            adaptive_learning_rate: false,
233            lr_schedule: LearningRateSchedule::Constant,
234            tikhonov_reg: 0.0,
235            line_search_method: LineSearchMethod::Binary,
236            base_learner,
237            base_models: Vec::new(),
238            scalings: Vec::new(),
239            init_params: None,
240            col_idxs: Vec::new(),
241            train_loss_monitor: None,
242            val_loss_monitor: None,
243            best_val_loss_itr: None,
244            n_features: None,
245            evals_result: EvalsResult::default(),
246            rng: StdRng::seed_from_u64(seed),
247            random_state: Some(seed),
248            _dist: PhantomData,
249            _score: PhantomData,
250        }
251    }
252
253    /// Set the random seed for reproducibility.
254    /// Call this before `fit()` to ensure reproducible results.
255    pub fn set_random_state(&mut self, seed: u64) {
256        self.random_state = Some(seed);
257        self.rng = StdRng::seed_from_u64(seed);
258    }
259
    /// Get the current random state (seed), if set.
    ///
    /// Returns `None` when the model was constructed without an explicit
    /// seed (i.e. the RNG was seeded from OS entropy).
    pub fn random_state(&self) -> Option<u64> {
        self.random_state
    }
264
    /// Returns a reference to the training history.
    ///
    /// `train` holds the per-iteration training loss; `val` is populated
    /// only when validation data was available during fitting.
    pub fn evals_result(&self) -> &EvalsResult {
        &self.evals_result
    }
269
270    pub fn with_options(
271        n_estimators: u32,
272        learning_rate: f64,
273        base_learner: B,
274        natural_gradient: bool,
275        minibatch_frac: f64,
276        col_sample: f64,
277        verbose: bool,
278        verbose_eval: f64,
279        tol: f64,
280        early_stopping_rounds: Option<u32>,
281        validation_fraction: f64,
282        adaptive_learning_rate: bool,
283    ) -> Self {
284        NGBoost {
285            n_estimators,
286            learning_rate,
287            natural_gradient,
288            minibatch_frac,
289            col_sample,
290            verbose,
291            verbose_eval,
292            tol,
293            early_stopping_rounds,
294            validation_fraction,
295            adaptive_learning_rate,
296            lr_schedule: LearningRateSchedule::Constant,
297            tikhonov_reg: 0.0,
298            line_search_method: LineSearchMethod::Binary,
299            base_learner,
300            base_models: Vec::new(),
301            scalings: Vec::new(),
302            init_params: None,
303            col_idxs: Vec::new(),
304            train_loss_monitor: None,
305            val_loss_monitor: None,
306            best_val_loss_itr: None,
307            n_features: None,
308            evals_result: EvalsResult::default(),
309            rng: StdRng::from_rng(&mut rand::rng()),
310            random_state: None,
311            _dist: PhantomData,
312            _score: PhantomData,
313        }
314    }
315
316    /// Create NGBoost with all options including random seed for reproducibility.
317    #[allow(clippy::too_many_arguments)]
318    pub fn with_options_seeded(
319        n_estimators: u32,
320        learning_rate: f64,
321        base_learner: B,
322        natural_gradient: bool,
323        minibatch_frac: f64,
324        col_sample: f64,
325        verbose: bool,
326        verbose_eval: f64,
327        tol: f64,
328        early_stopping_rounds: Option<u32>,
329        validation_fraction: f64,
330        adaptive_learning_rate: bool,
331        random_state: Option<u64>,
332    ) -> Self {
333        let rng = match random_state {
334            Some(seed) => StdRng::seed_from_u64(seed),
335            None => StdRng::from_rng(&mut rand::rng()),
336        };
337        NGBoost {
338            n_estimators,
339            learning_rate,
340            natural_gradient,
341            minibatch_frac,
342            col_sample,
343            verbose,
344            verbose_eval,
345            tol,
346            early_stopping_rounds,
347            validation_fraction,
348            adaptive_learning_rate,
349            lr_schedule: LearningRateSchedule::Constant,
350            tikhonov_reg: 0.0,
351            line_search_method: LineSearchMethod::Binary,
352            base_learner,
353            base_models: Vec::new(),
354            scalings: Vec::new(),
355            init_params: None,
356            col_idxs: Vec::new(),
357            train_loss_monitor: None,
358            val_loss_monitor: None,
359            best_val_loss_itr: None,
360            n_features: None,
361            evals_result: EvalsResult::default(),
362            rng,
363            random_state,
364            _dist: PhantomData,
365            _score: PhantomData,
366        }
367    }
368
369    /// Create NGBoost with advanced options including learning rate schedule and regularization.
370    #[allow(clippy::too_many_arguments)]
371    pub fn with_advanced_options(
372        n_estimators: u32,
373        learning_rate: f64,
374        base_learner: B,
375        natural_gradient: bool,
376        minibatch_frac: f64,
377        col_sample: f64,
378        verbose: bool,
379        verbose_eval: f64,
380        tol: f64,
381        early_stopping_rounds: Option<u32>,
382        validation_fraction: f64,
383        lr_schedule: LearningRateSchedule,
384        tikhonov_reg: f64,
385        line_search_method: LineSearchMethod,
386    ) -> Self {
387        NGBoost {
388            n_estimators,
389            learning_rate,
390            natural_gradient,
391            minibatch_frac,
392            col_sample,
393            verbose,
394            verbose_eval,
395            tol,
396            early_stopping_rounds,
397            validation_fraction,
398            adaptive_learning_rate: false,
399            lr_schedule,
400            tikhonov_reg,
401            line_search_method,
402            base_learner,
403            base_models: Vec::new(),
404            scalings: Vec::new(),
405            init_params: None,
406            col_idxs: Vec::new(),
407            train_loss_monitor: None,
408            val_loss_monitor: None,
409            best_val_loss_itr: None,
410            n_features: None,
411            evals_result: EvalsResult::default(),
412            rng: StdRng::from_rng(&mut rand::rng()),
413            random_state: None,
414            _dist: PhantomData,
415            _score: PhantomData,
416        }
417    }
418
419    /// Create NGBoost with all advanced options including random seed for reproducibility.
420    #[allow(clippy::too_many_arguments)]
421    pub fn with_full_options(
422        n_estimators: u32,
423        learning_rate: f64,
424        base_learner: B,
425        natural_gradient: bool,
426        minibatch_frac: f64,
427        col_sample: f64,
428        verbose: bool,
429        verbose_eval: f64,
430        tol: f64,
431        early_stopping_rounds: Option<u32>,
432        validation_fraction: f64,
433        lr_schedule: LearningRateSchedule,
434        tikhonov_reg: f64,
435        line_search_method: LineSearchMethod,
436        random_state: Option<u64>,
437    ) -> Self {
438        let rng = match random_state {
439            Some(seed) => StdRng::seed_from_u64(seed),
440            None => StdRng::from_rng(&mut rand::rng()),
441        };
442        NGBoost {
443            n_estimators,
444            learning_rate,
445            natural_gradient,
446            minibatch_frac,
447            col_sample,
448            verbose,
449            verbose_eval,
450            tol,
451            early_stopping_rounds,
452            validation_fraction,
453            adaptive_learning_rate: false,
454            lr_schedule,
455            tikhonov_reg,
456            line_search_method,
457            base_learner,
458            base_models: Vec::new(),
459            scalings: Vec::new(),
460            init_params: None,
461            col_idxs: Vec::new(),
462            train_loss_monitor: None,
463            val_loss_monitor: None,
464            best_val_loss_itr: None,
465            n_features: None,
466            evals_result: EvalsResult::default(),
467            rng,
468            random_state,
469            _dist: PhantomData,
470            _score: PhantomData,
471        }
472    }
473
    /// Set a custom training loss monitor function.
    ///
    /// The monitor receives the current distribution, the training targets,
    /// and optional sample weights, and returns the loss value to record
    /// for the training set (used in place of the default score).
    pub fn set_train_loss_monitor(&mut self, monitor: LossMonitor<D>) {
        self.train_loss_monitor = Some(monitor);
    }
478
    /// Set a custom validation loss monitor function.
    ///
    /// The monitor receives the current distribution, the validation
    /// targets, and optional sample weights; its return value replaces the
    /// default `total_score` as the tracked validation loss.
    pub fn set_val_loss_monitor(&mut self, monitor: LossMonitor<D>) {
        self.val_loss_monitor = Some(monitor);
    }
483
    /// Fit the model to features `x` (n_samples x n_features) and targets
    /// `y`, resetting any previously trained state first.
    ///
    /// Convenience wrapper around `fit_with_validation` with no validation
    /// data and no sample weights.
    ///
    /// # Errors
    /// Returns an error for invalid hyperparameters, mismatched or empty
    /// inputs, or non-finite values in `x`/`y`.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.fit_with_validation(x, y, None, None, None, None)
    }
487
    /// Fits an NGBoost model to the data appending base models to the existing ones.
    ///
    /// NOTE: This method is similar to Python's partial_fit. The first call will be the most
    /// significant and later calls will retune the model to newer data.
    ///
    /// Unlike `fit()`, this method does NOT reset the model state, allowing incremental learning.
    ///
    /// # Errors
    /// Returns an error for invalid hyperparameters, mismatched or empty
    /// inputs, or non-finite values in `x`/`y`.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.partial_fit_with_validation(x, y, None, None, None, None)
    }
497
    /// Partial fit with validation data support.
    ///
    /// Appends new base models without clearing previously trained state —
    /// the incremental counterpart of `fit_with_validation`.
    ///
    /// * `x`, `y` — training features and targets.
    /// * `x_val`, `y_val` — optional validation set for loss tracking.
    /// * `sample_weight`, `val_sample_weight` — optional per-sample weights.
    ///
    /// # Errors
    /// Returns an error for invalid hyperparameters or malformed inputs.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        // Don't reset state - this is the key difference from fit()
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, false)
    }
511
    /// Fit the model with optional validation data and sample weights,
    /// resetting any previously trained state first.
    ///
    /// When `early_stopping_rounds` is set and no explicit validation set
    /// is given, a `validation_fraction` share of the training data is
    /// split off automatically for validation.
    ///
    /// * `x`, `y` — training features and targets.
    /// * `x_val`, `y_val` — optional validation set.
    /// * `sample_weight`, `val_sample_weight` — optional per-sample weights.
    ///
    /// # Errors
    /// Returns an error for invalid hyperparameters, mismatched or empty
    /// inputs, or non-finite values in `x`/`y`.
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.fit_internal(x, y, x_val, y_val, sample_weight, val_sample_weight, true)
    }
523
524    /// Validates hyperparameters before fitting.
525    /// Returns an error message if any hyperparameter is invalid.
526    fn validate_hyperparameters(&self) -> Result<(), &'static str> {
527        if self.n_estimators == 0 {
528            return Err("n_estimators must be greater than 0");
529        }
530        if self.learning_rate <= 0.0 {
531            return Err("learning_rate must be positive");
532        }
533        if self.learning_rate > 10.0 {
534            return Err("learning_rate > 10.0 is likely a mistake");
535        }
536        if self.minibatch_frac <= 0.0 || self.minibatch_frac > 1.0 {
537            return Err("minibatch_frac must be in (0, 1]");
538        }
539        if self.col_sample <= 0.0 || self.col_sample > 1.0 {
540            return Err("col_sample must be in (0, 1]");
541        }
542        if self.tol < 0.0 {
543            return Err("tol must be non-negative");
544        }
545        if self.validation_fraction < 0.0 || self.validation_fraction >= 1.0 {
546            return Err("validation_fraction must be in [0, 1)");
547        }
548        if self.tikhonov_reg < 0.0 {
549            return Err("tikhonov_reg must be non-negative");
550        }
551
552        // Validate learning rate schedule parameters
553        match self.lr_schedule {
554            LearningRateSchedule::Linear {
555                decay_rate,
556                min_lr_fraction,
557            } => {
558                if decay_rate < 0.0 || decay_rate > 1.0 {
559                    return Err("Linear schedule decay_rate must be in [0, 1]");
560                }
561                if min_lr_fraction < 0.0 || min_lr_fraction > 1.0 {
562                    return Err("Linear schedule min_lr_fraction must be in [0, 1]");
563                }
564            }
565            LearningRateSchedule::Exponential { decay_rate } => {
566                if decay_rate < 0.0 {
567                    return Err("Exponential schedule decay_rate must be non-negative");
568                }
569            }
570            LearningRateSchedule::CosineWarmRestarts { restart_period } => {
571                if restart_period == 0 {
572                    return Err("CosineWarmRestarts restart_period must be > 0");
573                }
574            }
575            _ => {}
576        }
577
578        // Validate line search parameters
579        if let LineSearchMethod::GoldenSection { max_iters } = self.line_search_method {
580            if max_iters == 0 {
581                return Err("GoldenSection max_iters must be > 0");
582            }
583        }
584
585        Ok(())
586    }
587
588    /// Internal fit implementation that can optionally reset state.
589    fn fit_internal(
590        &mut self,
591        x: &Array2<f64>,
592        y: &Array1<f64>,
593        x_val: Option<&Array2<f64>>,
594        y_val: Option<&Array1<f64>>,
595        sample_weight: Option<&Array1<f64>>,
596        val_sample_weight: Option<&Array1<f64>>,
597        reset_state: bool,
598    ) -> Result<(), &'static str> {
599        // Validate hyperparameters first
600        self.validate_hyperparameters()?;
601
602        // Validate input dimensions with more detailed error messages
603        if x.nrows() != y.len() {
604            return Err("Number of samples in X and y must match");
605        }
606        if x.nrows() == 0 {
607            return Err("Cannot fit to empty dataset");
608        }
609        if x.ncols() == 0 {
610            return Err("Cannot fit to dataset with no features");
611        }
612
613        // Check for NaN/Inf values in input data
614        if x.iter().any(|&v| !v.is_finite()) {
615            return Err("Input X contains NaN or infinite values");
616        }
617        if y.iter().any(|&v| !v.is_finite()) {
618            return Err("Input y contains NaN or infinite values");
619        }
620
621        // Reset state only if requested (fit() resets, partial_fit() doesn't)
622        if reset_state {
623            self.base_models.clear();
624            self.scalings.clear();
625            self.col_idxs.clear();
626            self.best_val_loss_itr = None;
627            self.evals_result = EvalsResult::default();
628        }
629        self.n_features = Some(x.ncols());
630
631        // Handle automatic validation split if early stopping is enabled
632        let (x_train, y_train, x_val_auto, y_val_auto) = if self.early_stopping_rounds.is_some()
633            && x_val.is_none()
634            && y_val.is_none()
635            && self.validation_fraction > 0.0
636            && self.validation_fraction < 1.0
637        {
638            // Split training data into training and validation sets
639            // Shuffle indices first to match sklearn's train_test_split behavior
640            let n_samples = x.nrows();
641            let n_val = ((n_samples as f64) * self.validation_fraction) as usize;
642            let n_train = n_samples - n_val;
643
644            // Shuffle indices for random split (matches Python's train_test_split)
645            let mut indices: Vec<usize> = (0..n_samples).collect();
646            for i in (1..indices.len()).rev() {
647                let j = self.rng.random_range(0..=i);
648                indices.swap(i, j);
649            }
650
651            let train_indices: Vec<usize> = indices[0..n_train].to_vec();
652            let val_indices: Vec<usize> = indices[n_train..].to_vec();
653
654            let x_train = x.select(ndarray::Axis(0), &train_indices);
655            let y_train = y.select(ndarray::Axis(0), &train_indices);
656            let x_val_auto = Some(x.select(ndarray::Axis(0), &val_indices));
657            let y_val_auto = Some(y.select(ndarray::Axis(0), &val_indices));
658
659            (x_train, y_train, x_val_auto, y_val_auto)
660        } else {
661            (x.to_owned(), y.to_owned(), x_val.cloned(), y_val.cloned())
662        };
663
664        // Use the automatically split or provided validation data
665        let x_train = x_train;
666        let y_train = y_train;
667        let x_val = x_val_auto.as_ref().or(x_val);
668        let y_val = y_val_auto.as_ref().or(y_val);
669
670        // Validate validation data if provided
671        if let (Some(xv), Some(yv)) = (x_val, y_val) {
672            if xv.nrows() != yv.len() {
673                return Err("Number of samples in validation X and y must match");
674            }
675            if xv.ncols() != x_train.ncols() {
676                return Err("Number of features in training and validation data must match");
677            }
678        }
679
680        self.init_params = Some(D::fit(&y_train));
681        let n_params = self.init_params.as_ref().unwrap().len();
682        let mut params = Array2::from_elem((x_train.nrows(), n_params), 0.0);
683
684        // Safe unwrap with proper error handling
685        let init_params = self.init_params.as_ref().unwrap();
686        params
687            .outer_iter_mut()
688            .for_each(|mut row| row.assign(init_params));
689
690        // Prepare validation params if validation data is provided
691        let mut val_params = if let (Some(xv), Some(_yv)) = (x_val, y_val) {
692            let mut v_params = Array2::from_elem((xv.nrows(), n_params), 0.0);
693            v_params
694                .outer_iter_mut()
695                .for_each(|mut row| row.assign(init_params));
696            Some(v_params)
697        } else {
698            None
699        };
700
701        let mut best_val_loss = f64::INFINITY;
702        let mut best_iter = 0;
703        let mut no_improvement_count = 0;
704
705        for itr in 0..self.n_estimators {
706            let dist = D::from_params(&params);
707
708            // Compute gradients with optional Tikhonov regularization
709            let grads = if self.natural_gradient && self.tikhonov_reg > 0.0 {
710                // Use regularized natural gradient for better numerical stability
711                let standard_grad = Scorable::d_score(&dist, &y_train);
712                let metric = Scorable::metric(&dist);
713                crate::scores::natural_gradient_regularized(
714                    &standard_grad,
715                    &metric,
716                    self.tikhonov_reg,
717                )
718            } else {
719                Scorable::grad(&dist, &y_train, self.natural_gradient)
720            };
721
722            // Sample data for this iteration
723            let (row_idxs, col_idxs, x_sampled, y_sampled, params_sampled, weight_sampled) =
724                self.sample(&x_train, &y_train, &params, sample_weight);
725            self.col_idxs.push(col_idxs.clone());
726
727            let grads_sampled = grads.select(ndarray::Axis(0), &row_idxs);
728
729            // Fit base learners for each parameter - parallelized when feature enabled
730            #[cfg(feature = "parallel")]
731            let fit_results: Vec<
732                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
733            > = {
734                // Pre-clone learners to avoid borrow issues in parallel iterator
735                let learners: Vec<B> = (0..n_params).map(|_| self.base_learner.clone()).collect();
736                learners
737                    .into_par_iter()
738                    .enumerate()
739                    .map(|(j, learner)| {
740                        let grad_j = grads_sampled.column(j).to_owned();
741                        let fitted = learner.fit_with_weights(
742                            &x_sampled,
743                            &grad_j,
744                            weight_sampled.as_ref(),
745                        )?;
746                        let preds = fitted.predict(&x_sampled);
747                        Ok((fitted, preds))
748                    })
749                    .collect()
750            };
751
752            #[cfg(not(feature = "parallel"))]
753            let fit_results: Vec<
754                Result<(Box<dyn TrainedBaseLearner>, Array1<f64>), &'static str>,
755            > = (0..n_params)
756                .map(|j| {
757                    let grad_j = grads_sampled.column(j).to_owned();
758                    let learner = self.base_learner.clone();
759                    let fitted =
760                        learner.fit_with_weights(&x_sampled, &grad_j, weight_sampled.as_ref())?;
761                    let preds = fitted.predict(&x_sampled);
762                    Ok((fitted, preds))
763                })
764                .collect();
765
766            // Unpack results, propagating any errors
767            let mut fitted_learners: Vec<Box<dyn TrainedBaseLearner>> =
768                Vec::with_capacity(n_params);
769            let mut predictions_cols: Vec<Array1<f64>> = Vec::with_capacity(n_params);
770            for result in fit_results {
771                let (fitted, preds) = result?;
772                fitted_learners.push(fitted);
773                predictions_cols.push(preds);
774            }
775
776            let predictions = to_2d_array(predictions_cols);
777
778            let scale = self.line_search(
779                &predictions,
780                &params_sampled,
781                &y_sampled,
782                weight_sampled.as_ref(),
783            );
784            self.scalings.push(scale);
785            self.base_models.push(fitted_learners);
786
787            // Apply learning rate schedule
788            let progress = itr as f64 / self.n_estimators as f64;
789            let effective_learning_rate = self.compute_learning_rate(itr, progress);
790
791            // Update parameters for ALL training samples by re-predicting on full X
792            // This matches Python's behavior: after fitting base learners on minibatch,
793            // we predict on the FULL training set to update all parameters
794            // This is critical for correct convergence with minibatch_frac < 1.0
795            let fitted_learners = self.base_models.last().unwrap();
796            let full_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == x_train.ncols() {
797                fitted_learners
798                    .iter()
799                    .map(|learner| learner.predict(&x_train))
800                    .collect()
801            } else {
802                let x_subset = x_train.select(ndarray::Axis(1), &col_idxs);
803                fitted_learners
804                    .iter()
805                    .map(|learner| learner.predict(&x_subset))
806                    .collect()
807            };
808            let full_predictions = to_2d_array(full_predictions_cols);
809
810            params -= &(effective_learning_rate * scale * &full_predictions);
811
812            // Update validation parameters if validation data is provided
813            if let (Some(xv), Some(yv), Some(vp)) = (x_val, y_val, val_params.as_mut()) {
814                // Get predictions on validation data from the fitted base learners
815                // Apply column subsampling to match training
816                let fitted_learners = self.base_models.last().unwrap();
817                let val_predictions_cols: Vec<Array1<f64>> = if col_idxs.len() == xv.ncols() {
818                    fitted_learners
819                        .iter()
820                        .map(|learner| learner.predict(xv))
821                        .collect()
822                } else {
823                    let xv_subset = xv.select(ndarray::Axis(1), &col_idxs);
824                    fitted_learners
825                        .iter()
826                        .map(|learner| learner.predict(&xv_subset))
827                        .collect()
828                };
829                let val_predictions = to_2d_array(val_predictions_cols);
830                *vp -= &(effective_learning_rate * scale * &val_predictions);
831
832                // Calculate validation loss using monitor or default
833                let val_dist = D::from_params(vp);
834                let val_loss = if let Some(monitor) = &self.val_loss_monitor {
835                    monitor(&val_dist, yv, val_sample_weight)
836                } else {
837                    Scorable::total_score(&val_dist, yv, val_sample_weight)
838                };
839
840                // Track validation loss in evals_result
841                self.evals_result.val.push(val_loss);
842
843                // Early stopping logic
844                if val_loss < best_val_loss {
845                    best_val_loss = val_loss;
846                    best_iter = itr;
847                    no_improvement_count = 0;
848                    self.best_val_loss_itr = Some(itr as usize);
849                } else {
850                    no_improvement_count += 1;
851                }
852
853                // Check if we should stop early
854                if let Some(rounds) = self.early_stopping_rounds {
855                    if no_improvement_count >= rounds {
856                        if self.verbose {
857                            println!("== Early stopping achieved.");
858                            println!(
859                                "== Best iteration / VAL{} (val_loss={:.4})",
860                                best_iter, best_val_loss
861                            );
862                        }
863                        break;
864                    }
865                }
866
867                // Calculate and track training loss
868                let dist = D::from_params(&params);
869                let train_loss = if let Some(monitor) = &self.train_loss_monitor {
870                    monitor(&dist, &y_train, sample_weight)
871                } else {
872                    Scorable::total_score(&dist, &y_train, sample_weight)
873                };
874                self.evals_result.train.push(train_loss);
875
876                // Verbose logging with validation
877                if self.should_print_verbose(itr) {
878                    println!(
879                        "[iter {}] train_loss={:.4} val_loss={:.4}",
880                        itr, train_loss, val_loss
881                    );
882                }
883            } else {
884                // Calculate and track training loss
885                let dist = D::from_params(&params);
886                let train_loss = if let Some(monitor) = &self.train_loss_monitor {
887                    monitor(&dist, &y_train, sample_weight)
888                } else {
889                    Scorable::total_score(&dist, &y_train, sample_weight)
890                };
891                self.evals_result.train.push(train_loss);
892
893                // Verbose logging without validation
894                if self.should_print_verbose(itr) {
895                    // Calculate gradient norm for debugging
896                    let grad_norm: f64 =
897                        grads.iter().map(|x| x * x).sum::<f64>().sqrt() / grads.len() as f64;
898
899                    println!(
900                        "[iter {}] loss={:.4} grad_norm={:.4} scale={:.4}",
901                        itr, train_loss, grad_norm, scale
902                    );
903                }
904            }
905        }
906
907        Ok(())
908    }
909
    /// Draw the row minibatch and feature subset used for one boosting iteration.
    ///
    /// Returns `(row_idxs, col_idxs, x_sampled, y_sampled, params_sampled,
    /// weights_sampled)`. Rows are drawn uniformly without replacement; sample
    /// weights are NOT used for selection, only sliced and forwarded to the
    /// base learner (matching the Python reference implementation).
    fn sample(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        params: &Array2<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> (
        Vec<usize>,
        Vec<usize>,
        Array2<f64>,
        Array1<f64>,
        Array2<f64>,
        Option<Array1<f64>>,
    ) {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Sample rows (minibatch)
        let sample_size = if self.minibatch_frac >= 1.0 {
            n_samples
        } else {
            // Truncating cast: e.g. frac=0.5 over 7 rows keeps 3.
            ((n_samples as f64) * self.minibatch_frac) as usize
        };

        // Uniform random sampling without replacement (matches Python's np.random.choice behavior)
        // Note: Python does NOT do weighted sampling for minibatch selection,
        // it only passes the weights to the base learner's fit method
        let row_idxs: Vec<usize> = if sample_size == n_samples {
            (0..n_samples).collect()
        } else {
            let mut indices: Vec<usize> = (0..n_samples).collect();
            // Use Fisher-Yates shuffle for better randomness (matches numpy's algorithm)
            for i in (1..indices.len()).rev() {
                let j = self.rng.random_range(0..=i);
                indices.swap(i, j);
            }
            indices.into_iter().take(sample_size).collect()
        };

        // Sample columns
        let col_size = if self.col_sample >= 1.0 {
            n_features
        } else if self.col_sample > 0.0 {
            // Truncating cast; a small col_sample on few features can yield 0.
            ((n_features as f64) * self.col_sample) as usize
        } else {
            0
        };

        // NOTE(review): col_size == 0 (col_sample <= 0, or a fraction that
        // truncates to zero) falls back to keeping ALL features rather than
        // none — presumably an intentional safety net; confirm.
        let col_idxs: Vec<usize> = if col_size == n_features || col_size == 0 {
            (0..n_features).collect()
        } else {
            let mut indices: Vec<usize> = (0..n_features).collect();
            indices.shuffle(&mut self.rng);
            indices.into_iter().take(col_size).collect()
        };

        // Create sampled data with optimized single-pass selection
        // Instead of two sequential selects (which create an intermediate array),
        // we directly construct the result array
        let x_sampled = if col_size == n_features {
            // No column sampling - just select rows (single allocation)
            x.select(ndarray::Axis(0), &row_idxs)
        } else {
            // Both row and column sampling - single allocation with direct indexing
            let mut result = Array2::zeros((row_idxs.len(), col_idxs.len()));
            for (new_row, &old_row) in row_idxs.iter().enumerate() {
                for (new_col, &old_col) in col_idxs.iter().enumerate() {
                    result[[new_row, new_col]] = x[[old_row, old_col]];
                }
            }
            result
        };
        let y_sampled = y.select(ndarray::Axis(0), &row_idxs);
        let params_sampled = params.select(ndarray::Axis(0), &row_idxs);

        // Handle sample weights
        let sample_weights_sampled =
            sample_weight.map(|weights| weights.select(ndarray::Axis(0), &row_idxs));

        (
            row_idxs,
            col_idxs,
            x_sampled,
            y_sampled,
            params_sampled,
            sample_weights_sampled,
        )
    }
998
    /// Predict distribution parameters for `x` using all trained boosting stages.
    fn get_params(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params_at(x, None)
    }
1002
1003    fn get_params_at(&self, x: &Array2<f64>, max_iter: Option<usize>) -> Array2<f64> {
1004        if x.nrows() == 0 {
1005            return Array2::zeros((0, 0));
1006        }
1007
1008        let init_params = self
1009            .init_params
1010            .as_ref()
1011            .expect("Model has not been fitted. Call fit() before predict().");
1012        let n_params = init_params.len();
1013        let mut params = Array2::from_elem((x.nrows(), n_params), 0.0);
1014        params
1015            .outer_iter_mut()
1016            .for_each(|mut row| row.assign(init_params));
1017
1018        let n_iters = max_iter
1019            .unwrap_or(self.base_models.len())
1020            .min(self.base_models.len());
1021
1022        for (i, (learners, col_idx)) in self
1023            .base_models
1024            .iter()
1025            .zip(self.col_idxs.iter())
1026            .enumerate()
1027            .take(n_iters)
1028        {
1029            let scale = self.scalings[i];
1030
1031            // Apply column subsampling during prediction to match training
1032            // This is critical when col_sample < 1.0
1033            let predictions_cols: Vec<Array1<f64>> = if col_idx.len() == x.ncols() {
1034                learners.iter().map(|learner| learner.predict(x)).collect()
1035            } else {
1036                let x_subset = x.select(ndarray::Axis(1), col_idx);
1037                learners
1038                    .iter()
1039                    .map(|learner| learner.predict(&x_subset))
1040                    .collect()
1041            };
1042
1043            let predictions = to_2d_array(predictions_cols);
1044
1045            params -= &(self.learning_rate * scale * &predictions);
1046        }
1047        params
1048    }
1049
    /// Get the predicted distribution parameters (like Python's `pred_param`).
    ///
    /// # Panics
    /// Panics if the model has not been fitted and `x` is non-empty.
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.get_params(x)
    }
1054
    /// Get the predicted distribution parameters up to a specific iteration
    /// (`max_iter` is clamped to the number of trained stages).
    pub fn pred_param_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        self.get_params_at(x, Some(max_iter))
    }
1059
    /// Predict the full output distribution for each row of `x`.
    ///
    /// # Panics
    /// Panics if the model has not been fitted and `x` is non-empty.
    pub fn pred_dist(&self, x: &Array2<f64>) -> D {
        let params = self.get_params(x);
        D::from_params(&params)
    }
1064
    /// Get the predicted distribution up to a specific iteration
    /// (`max_iter` is clamped to the number of trained stages).
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> D {
        let params = self.get_params_at(x, Some(max_iter));
        D::from_params(&params)
    }
1070
    /// Point predictions for each row of `x` (delegates to the predicted
    /// distribution's `predict`).
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.pred_dist(x).predict()
    }
1074
    /// Get point predictions using at most `max_iter` boosting stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.pred_dist_at(x, max_iter).predict()
    }
1079
    /// Returns an iterator over staged predictions (predictions at each boosting iteration),
    /// from 1 trained stage up to all of them.
    ///
    /// Evaluation is lazy; each item replays the ensemble from scratch via
    /// `predict_at`, so consuming all stages costs O(stages^2) learner calls.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        (1..=self.base_models.len()).map(move |i| self.predict_at(x, i))
    }
1087
    /// Returns an iterator over staged distribution predictions, one per
    /// boosting stage (same lazy replay cost caveat as `staged_predict`).
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = D> + 'a {
        (1..=self.base_models.len()).map(move |i| self.pred_dist_at(x, i))
    }
1092
    /// Compute the score (loss) of the fitted model on `(x, y)` with no
    /// sample weights, via `Scorable::total_score`.
    /// NOTE(review): whether this is a sum or a mean depends on
    /// `total_score`'s definition — confirm before relying on "average".
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        let dist = self.pred_dist(x);
        Scorable::total_score(&dist, y, None)
    }
1098
    /// Get number of features the model was trained on (`None` before fitting).
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }
1103
1104    /// Determine if verbose output should be printed at the given iteration.
1105    /// Handles both integer intervals (verbose_eval >= 1.0) and percentage intervals (0 < verbose_eval < 1.0).
1106    fn should_print_verbose(&self, iteration: u32) -> bool {
1107        if !self.verbose || self.verbose_eval <= 0.0 {
1108            return false;
1109        }
1110
1111        // Compute verbose_eval interval:
1112        // - If >= 1.0: use as integer iteration count (e.g., 100 = every 100 iterations)
1113        // - If 0 < x < 1.0: use as percentage of n_estimators (e.g., 0.1 = every 10%)
1114        let verbose_interval = if self.verbose_eval >= 1.0 {
1115            self.verbose_eval as u32
1116        } else {
1117            // Percentage of total iterations
1118            (self.n_estimators as f64 * self.verbose_eval).max(1.0) as u32
1119        };
1120
1121        verbose_interval > 0 && iteration % verbose_interval == 0
1122    }
1123
1124    /// Compute the effective learning rate for the given iteration using the configured schedule.
1125    fn compute_learning_rate(&self, iteration: u32, progress: f64) -> f64 {
1126        // Legacy adaptive_learning_rate takes precedence for backward compatibility
1127        if self.adaptive_learning_rate {
1128            return self.learning_rate * (1.0 - 0.7 * progress).max(0.1);
1129        }
1130
1131        match self.lr_schedule {
1132            LearningRateSchedule::Constant => self.learning_rate,
1133            LearningRateSchedule::Linear {
1134                decay_rate,
1135                min_lr_fraction,
1136            } => self.learning_rate * (1.0 - decay_rate * progress).max(min_lr_fraction),
1137            LearningRateSchedule::Exponential { decay_rate } => {
1138                self.learning_rate * (-decay_rate * progress).exp()
1139            }
1140            LearningRateSchedule::Cosine => {
1141                self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * progress).cos())
1142            }
1143            LearningRateSchedule::CosineWarmRestarts { restart_period } => {
1144                let period_progress = (iteration % restart_period) as f64 / restart_period as f64;
1145                self.learning_rate * 0.5 * (1.0 + (std::f64::consts::PI * period_progress).cos())
1146            }
1147        }
1148    }
1149
1150    /// Compute feature importances based on how often each feature is used in splits.
1151    /// Returns a 2D array of shape (n_params, n_features) where each row contains
1152    /// the normalized feature importances for that distribution parameter.
1153    /// Returns None if the model hasn't been trained or has no features.
1154    pub fn feature_importances(&self) -> Option<Array2<f64>> {
1155        let n_features = self.n_features?;
1156        if self.base_models.is_empty() || n_features == 0 {
1157            return None;
1158        }
1159
1160        let n_params = self.init_params.as_ref()?.len();
1161        let mut importances = Array2::zeros((n_params, n_features));
1162
1163        // Aggregate feature usage across all iterations, weighted by scaling factor
1164        for (iter_idx, learners) in self.base_models.iter().enumerate() {
1165            let scale = self.scalings[iter_idx].abs();
1166
1167            for (param_idx, learner) in learners.iter().enumerate() {
1168                if let Some(feature_idx) = learner.split_feature() {
1169                    if feature_idx < n_features {
1170                        importances[[param_idx, feature_idx]] += scale;
1171                    }
1172                }
1173            }
1174        }
1175
1176        // Normalize each parameter's importances to sum to 1
1177        for mut row in importances.rows_mut() {
1178            let sum: f64 = row.sum();
1179            if sum > 0.0 {
1180                row.mapv_inplace(|v| v / sum);
1181            }
1182        }
1183
1184        Some(importances)
1185    }
1186
1187    /// Calibrate uncertainty estimates using isotonic regression on validation data
1188    /// This improves the quality of probabilistic predictions by adjusting the variance estimates
1189    pub fn calibrate_uncertainty(
1190        &mut self,
1191        x_val: &Array2<f64>,
1192        y_val: &Array1<f64>,
1193    ) -> Result<(), &'static str> {
1194        if self.base_models.is_empty() {
1195            return Err("Model must be trained before calibration");
1196        }
1197
1198        // Get predictions on validation data
1199        let params = self.pred_param(x_val);
1200        let dist = D::from_params(&params);
1201
1202        // Calculate predictions and errors
1203        let predictions = dist.predict();
1204        let errors = y_val - &predictions;
1205
1206        // Calculate empirical variance
1207        let empirical_var = errors.mapv(|e| e * e).mean().unwrap_or(1.0);
1208
1209        // For normal distribution (2 parameters), adjust the scale parameter
1210        if let Some(init_params) = self.init_params.as_mut() {
1211            if init_params.len() >= 2 {
1212                // The second parameter is log(scale), so we adjust it based on empirical variance
1213                let current_var = (-init_params[1]).exp(); // exp(2*log(scale)) = scale^2
1214                let target_var = empirical_var;
1215                let calibration_factor = (target_var / current_var).sqrt();
1216                init_params[1] += calibration_factor.ln();
1217            }
1218        }
1219
1220        Ok(())
1221    }
1222
1223    /// Compute aggregated feature importances across all distribution parameters.
1224    /// Returns a 1D array of length n_features with normalized importances.
1225    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
1226        let importances = self.feature_importances()?;
1227        let mut aggregated = importances.sum_axis(ndarray::Axis(0));
1228
1229        let sum: f64 = aggregated.sum();
1230        if sum > 0.0 {
1231            aggregated.mapv_inplace(|v| v / sum);
1232        }
1233
1234        Some(aggregated)
1235    }
1236
1237    fn line_search(
1238        &self,
1239        resids: &Array2<f64>,
1240        start: &Array2<f64>,
1241        y: &Array1<f64>,
1242        sample_weight: Option<&Array1<f64>>,
1243    ) -> f64 {
1244        match self.line_search_method {
1245            LineSearchMethod::Binary => self.line_search_binary(resids, start, y, sample_weight),
1246            LineSearchMethod::GoldenSection { max_iters } => {
1247                self.line_search_golden_section(resids, start, y, sample_weight, max_iters)
1248            }
1249        }
1250    }
1251
    /// Binary line search (original NGBoost method).
    ///
    /// Phase 1 doubles the step while the doubled step still strictly reduces
    /// the loss (scale can grow up to 512). Phase 2 halves the step until one
    /// of: the step strictly reduces the loss, the step is numerically
    /// negligible (mean per-row L2 norm < tol), or the scale underflows past
    /// 1e-10. The returned scale is therefore not guaranteed to improve the
    /// loss when no improving step exists.
    fn line_search_binary(
        &self,
        resids: &Array2<f64>,
        start: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> f64 {
        let mut scale = 1.0;
        // Loss at the current parameters; every candidate step must beat this.
        let initial_score = Scorable::total_score(&D::from_params(start), y, sample_weight);

        // Scale up phase: try to find a larger step that still reduces loss
        loop {
            if scale > 256.0 {
                break;
            }
            let scaled_resids = resids * (scale * 2.0);
            let next_params = start - &scaled_resids;
            let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
            if score >= initial_score || !score.is_finite() {
                break;
            }
            scale *= 2.0;
        }

        // Scale down phase: find a step that actually reduces loss
        loop {
            let scaled_resids = resids * scale;
            // Mean per-sample L2 norm of the update; below tol the step is
            // effectively zero and further shrinking is pointless.
            let norm: f64 = scaled_resids
                .rows()
                .into_iter()
                .map(|row| row.iter().map(|x| x * x).sum::<f64>().sqrt())
                .sum::<f64>()
                / scaled_resids.nrows() as f64;
            if norm < self.tol {
                break;
            }

            let next_params = start - &scaled_resids;
            let score = Scorable::total_score(&D::from_params(&next_params), y, sample_weight);
            if score < initial_score && score.is_finite() {
                break;
            }
            scale *= 0.5;

            // Hard floor guarantees termination even when no improving step exists.
            if scale < 1e-10 {
                break;
            }
        }

        scale
    }
1304
    /// Golden section line search for more accurate step size.
    /// Uses the golden ratio to efficiently narrow down the optimal step size.
    ///
    /// First brackets the minimum by doubling an upper bound (up to 512), then
    /// runs up to `max_iters` golden-section iterations on [0, upper] and
    /// returns the final interval's midpoint. If that midpoint does not
    /// strictly reduce the loss, falls back to the unit step (1.0).
    fn line_search_golden_section(
        &self,
        resids: &Array2<f64>,
        start: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
        max_iters: usize,
    ) -> f64 {
        // Helper to compute score at a given scale
        let compute_score = |scale: f64| -> f64 {
            let scaled_resids = resids * scale;
            let next_params = start - &scaled_resids;
            Scorable::total_score(&D::from_params(&next_params), y, sample_weight)
        };

        // Score at scale 0, i.e. the loss of the current (unchanged) parameters.
        let initial_score = compute_score(0.0);

        // First, find a reasonable upper bound by scaling up
        let mut upper = 1.0;
        while upper < 256.0 {
            let score = compute_score(upper * 2.0);
            if score >= initial_score || !score.is_finite() {
                break;
            }
            upper *= 2.0;
        }

        // Golden section search between 0 and upper
        let mut a = 0.0;
        let mut b = upper;
        let inv_phi = 1.0 / GOLDEN_RATIO;
        let _inv_phi2 = 1.0 / (GOLDEN_RATIO * GOLDEN_RATIO); // Available for Brent's method extension

        // Initial interior points
        let mut c = b - (b - a) * inv_phi;
        let mut d = a + (b - a) * inv_phi;
        let mut fc = compute_score(c);
        let mut fd = compute_score(d);

        for _ in 0..max_iters {
            if (b - a).abs() < self.tol {
                break;
            }

            if fc < fd {
                // Minimum is in [a, d]
                b = d;
                d = c;
                fd = fc;
                c = b - (b - a) * inv_phi;
                fc = compute_score(c);
            } else {
                // Minimum is in [c, b]
                a = c;
                c = d;
                fc = fd;
                d = a + (b - a) * inv_phi;
                fd = compute_score(d);
            }
        }

        // Return the midpoint of the final interval
        let scale = (a + b) / 2.0;

        // Verify the scale actually reduces loss, otherwise fall back
        let final_score = compute_score(scale);
        if final_score < initial_score && final_score.is_finite() {
            scale
        } else {
            // Midpoint did not improve the loss: fall back to the unit step
            // (1.0), mirroring the binary search's starting scale.
            1.0
        }
    }
1380
1381    /// Serialize the model to a platform-independent format
1382    pub fn serialize(&self) -> Result<SerializedNGBoost, Box<dyn std::error::Error>> {
1383        // Serialize base models
1384        let serialized_base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>> = self
1385            .base_models
1386            .iter()
1387            .map(|learners| {
1388                learners
1389                    .iter()
1390                    .filter_map(|learner| learner.to_serializable())
1391                    .collect()
1392            })
1393            .collect();
1394
1395        Ok(SerializedNGBoost {
1396            n_estimators: self.n_estimators,
1397            learning_rate: self.learning_rate,
1398            natural_gradient: self.natural_gradient,
1399            minibatch_frac: self.minibatch_frac,
1400            col_sample: self.col_sample,
1401            verbose: self.verbose,
1402            verbose_eval: self.verbose_eval,
1403            tol: self.tol,
1404            early_stopping_rounds: self.early_stopping_rounds,
1405            validation_fraction: self.validation_fraction,
1406            init_params: self.init_params.as_ref().map(|p| p.to_vec()),
1407            scalings: self.scalings.clone(),
1408            col_idxs: self.col_idxs.clone(),
1409            best_val_loss_itr: self.best_val_loss_itr,
1410            base_models: serialized_base_models,
1411            lr_schedule: self.lr_schedule,
1412            tikhonov_reg: self.tikhonov_reg,
1413            line_search_method: self.line_search_method,
1414            n_features: self.n_features,
1415            random_state: self.random_state,
1416        })
1417    }
1418
    /// Deserialize the model from a platform-independent format.
    ///
    /// Rebuilds a model from the serialized hyperparameters and restores the
    /// trained state: init params, scalings, column subsets, best-iteration
    /// marker, v0.3 options, and the base models.
    ///
    /// NOTE(review): `adaptive_learning_rate` is not part of the serialized
    /// format and is restored as `false`, so a model trained with the legacy
    /// flag will not replay the same schedule after a round-trip — confirm
    /// this is acceptable.
    pub fn deserialize(
        serialized: SerializedNGBoost,
        base_learner: B,
    ) -> Result<Self, Box<dyn std::error::Error>>
    where
        D: Distribution + Scorable<S> + Clone,
        S: Score,
        B: BaseLearner + Clone,
    {
        let mut model = Self::with_options_seeded(
            serialized.n_estimators,
            serialized.learning_rate,
            base_learner,
            serialized.natural_gradient,
            serialized.minibatch_frac,
            serialized.col_sample,
            serialized.verbose,
            serialized.verbose_eval,
            serialized.tol,
            serialized.early_stopping_rounds,
            serialized.validation_fraction,
            false, // Default adaptive_learning_rate to false for backward compatibility
            serialized.random_state,
        );

        // Restore trained state
        if let Some(init_params) = serialized.init_params {
            model.init_params = Some(Array1::from(init_params));
        }
        model.scalings = serialized.scalings;
        model.col_idxs = serialized.col_idxs;
        model.best_val_loss_itr = serialized.best_val_loss_itr;

        // Restore advanced options (added in v0.3)
        model.lr_schedule = serialized.lr_schedule;
        model.tikhonov_reg = serialized.tikhonov_reg;
        model.line_search_method = serialized.line_search_method;
        model.n_features = serialized.n_features;

        // Restore base models (one Vec of learners per boosting stage)
        model.base_models = serialized
            .base_models
            .into_iter()
            .map(|learners| learners.into_iter().map(|l| l.to_trait_object()).collect())
            .collect();

        Ok(model)
    }
1468}
1469
/// Serialized model data structure.
///
/// Fields added after the first release carry `#[serde(default)]` so payloads
/// produced by older versions still deserialize (missing fields take their
/// `Default` values).
#[derive(serde::Serialize, serde::Deserialize)]
pub struct SerializedNGBoost {
    pub n_estimators: u32,
    pub learning_rate: f64,
    pub natural_gradient: bool,
    pub minibatch_frac: f64,
    pub col_sample: f64,
    pub verbose: bool,
    /// Verbose evaluation interval. Can be >= 1.0 for iteration count or 0 < x < 1.0 for percentage.
    pub verbose_eval: f64,
    pub tol: f64,
    pub early_stopping_rounds: Option<u32>,
    pub validation_fraction: f64,
    /// Initial (pre-boosting) distribution parameters; `None` if never fitted.
    pub init_params: Option<Vec<f64>>,
    /// Per-iteration line-search scaling factors, parallel to `base_models`.
    pub scalings: Vec<f64>,
    /// Per-iteration column subsets used during training, parallel to `base_models`.
    pub col_idxs: Vec<Vec<usize>>,
    pub best_val_loss_itr: Option<usize>,
    /// Serialized base models - each inner Vec contains learners for each parameter
    pub base_models: Vec<Vec<crate::learners::SerializableTrainedLearner>>,
    /// Learning rate schedule (added in v0.3)
    #[serde(default)]
    pub lr_schedule: LearningRateSchedule,
    /// Tikhonov regularization parameter (added in v0.3)
    #[serde(default)]
    pub tikhonov_reg: f64,
    /// Line search method (added in v0.3)
    #[serde(default)]
    pub line_search_method: LineSearchMethod,
    /// Number of features the model was trained on (added in v0.3)
    #[serde(default)]
    pub n_features: Option<usize>,
    /// Random state for reproducibility (added in v0.3)
    #[serde(default)]
    pub random_state: Option<u64>,
}
1506
1507fn to_2d_array(cols: Vec<Array1<f64>>) -> Array2<f64> {
1508    if cols.is_empty() {
1509        return Array2::zeros((0, 0));
1510    }
1511    let nrows = cols[0].len();
1512    let ncols = cols.len();
1513    let mut arr = Array2::zeros((nrows, ncols));
1514    for (j, col) in cols.iter().enumerate() {
1515        arr.column_mut(j).assign(col);
1516    }
1517    arr
1518}
1519
// High-level API

/// Convenience regressor: NGBoost with a `Normal` output distribution,
/// `LogScore`, and the default decision-tree base learner.
pub struct NGBRegressor {
    model: NGBoost<Normal, LogScore, DecisionTreeLearner>,
}
1524
/// Convenience classifier: NGBoost with a `Bernoulli` output distribution,
/// `LogScore`, and the default decision-tree base learner.
pub struct NGBClassifier {
    model: NGBoost<Bernoulli, LogScore, DecisionTreeLearner>,
}
1528
1529impl NGBRegressor {
    /// Create a regressor with the given boosting-iteration count and learning
    /// rate; all other options use the `NGBoost::new` defaults.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }
1535
    /// Fit the model to `(x, y)`, resetting any previously trained state
    /// (use `partial_fit` for incremental training).
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }
1539
    /// Fit with optional held-out validation data for validation-loss tracking
    /// and early stopping. Forwards `None` for the two optional trailing
    /// arguments of `NGBoost::fit_with_validation` (presumably sample
    /// weights — see that method's signature to confirm).
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }
1550
    /// Fits an NGBoost model to the data appending base models to the existing ones.
    ///
    /// NOTE: This method is similar to Python's partial_fit. The first call will be the most
    /// significant and later calls will retune the model to newer data.
    ///
    /// Unlike `fit()`, this method does NOT reset the model state, allowing incremental learning.
    ///
    /// # Errors
    /// Propagates any error from the underlying `NGBoost::partial_fit`.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }
1560
    /// Partial fit with validation data support.
    /// Like `partial_fit`, appends to the existing ensemble; forwards `None`
    /// for the two optional trailing arguments of the underlying call.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }
1572
    /// Point predictions for each row of `x` (delegates to the inner model's
    /// `predict` on the fitted `Normal` distribution).
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }
1576
    /// Get point predictions using at most `max_iter` boosting stages.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }
1581
    /// Returns a lazy iterator over staged predictions, one per boosting stage
    /// (see `NGBoost::staged_predict`).
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }
1589
    /// Predict the full `Normal` output distribution for each row of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Normal {
        self.model.pred_dist(x)
    }
1593
    /// Get the predicted `Normal` distribution using at most `max_iter` stages.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Normal {
        self.model.pred_dist_at(x, max_iter)
    }
1598
    /// Returns a lazy iterator over staged `Normal` distribution predictions,
    /// one per boosting stage.
    pub fn staged_pred_dist<'a>(&'a self, x: &'a Array2<f64>) -> impl Iterator<Item = Normal> + 'a {
        self.model.staged_pred_dist(x)
    }
1603
1604    /// Get the predicted distribution parameters
1605    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
1606        self.model.pred_param(x)
1607    }
1608
1609    /// Compute the average score (loss) on the given data
1610    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
1611        self.model.score(x, y)
1612    }
1613
    /// Set a custom training loss monitor function
    ///
    /// The monitor receives the predicted distribution, the targets, and optional
    /// sample weights, and returns the loss value to record for that iteration.
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Set a custom validation loss monitor function
    ///
    /// Same signature as the training monitor, applied to validation data.
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Normal, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }
1629
    /// Enhanced constructor with all options
    ///
    /// Thin wrapper over `NGBoost::with_options` that fixes the base learner to
    /// `default_tree_learner()`. See that constructor for the exact semantics of
    /// each option; `minibatch_frac`, `col_sample` and `validation_fraction` are
    /// presumably fractions in `[0, 1]` — confirm in the inner implementation.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }
1661
    /// Enhanced constructor with all options (backward compatible version without adaptive_learning_rate)
    ///
    /// Identical to [`with_options`](Self::with_options) with
    /// `adaptive_learning_rate` forced to `false`.
    pub fn with_options_compat(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
    ) -> Self {
        Self::with_options(
            n_estimators,
            learning_rate,
            natural_gradient,
            minibatch_frac,
            col_sample,
            verbose,
            verbose_eval,
            tol,
            early_stopping_rounds,
            validation_fraction,
            false, // Default adaptive_learning_rate to false
        )
    }
1689
    /// Enable adaptive learning rate for better convergence in probabilistic forecasting
    pub fn set_adaptive_learning_rate(&mut self, enabled: bool) {
        self.model.adaptive_learning_rate = enabled;
    }

    /// Calibrate uncertainty estimates using validation data
    /// This improves the quality of probabilistic predictions
    ///
    /// # Errors
    /// Propagates any error reported by the underlying
    /// `NGBoost::calibrate_uncertainty`.
    pub fn calibrate_uncertainty(
        &mut self,
        x_val: &Array2<f64>,
        y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        self.model.calibrate_uncertainty(x_val, y_val)
    }
1704
    /// Get the number of estimators (boosting iterations)
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Get the learning rate
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Get whether natural gradient is used
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Get the minibatch fraction
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Get the column sampling fraction
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Get the best validation iteration
    /// (`None` presumably means no validation-based best iteration was recorded).
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Get early stopping rounds
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Get validation fraction
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Get number of features the model was trained on
    /// (`None` presumably before the first fit — confirm in `NGBoost::n_features`).
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Compute feature importances per distribution parameter.
    /// Returns a 2D array of shape (n_params, n_features).
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Compute aggregated feature importances across all distribution parameters.
    /// Returns a 1D array of length n_features with normalized importances.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }
1761
    /// Save model to file using bincode serialization
    ///
    /// # Errors
    /// Returns an error if model serialization, bincode encoding, or the file
    /// write fails.
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        let serialized = self.model.serialize()?;
        let encoded = bincode::serialize(&serialized)?;
        std::fs::write(path, encoded)?;
        Ok(())
    }

    /// Load model from file using bincode deserialization
    ///
    /// The file must have been produced by [`save_model`](Self::save_model);
    /// the base learner is re-created via `default_tree_learner()`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read, decoded, or deserialized
    /// into an `NGBoost` model.
    pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoded = std::fs::read(path)?;
        let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
        let model = NGBoost::<Normal, LogScore, DecisionTreeLearner>::deserialize(
            serialized,
            default_tree_learner(),
        )?;
        Ok(Self { model })
    }
1780
    /// Get the training history (losses at each iteration).
    pub fn evals_result(&self) -> &EvalsResult {
        self.model.evals_result()
    }

    /// Set the random seed for reproducibility.
    pub fn set_random_state(&mut self, seed: u64) {
        self.model.set_random_state(seed);
    }

    /// Get the current random state (seed), if set.
    pub fn random_state(&self) -> Option<u64> {
        self.model.random_state()
    }

    /// Fit with sample weights.
    ///
    /// `sample_weight`, when provided, is forwarded to the inner fit; presumably
    /// it must have one entry per row of `x` — confirm enforcement in
    /// `NGBoost::fit_with_validation`.
    pub fn fit_with_weights(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, None, None, sample_weight, None)
    }

    /// Fit with sample weights and validation data.
    pub fn fit_with_weights_and_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, sample_weight, val_sample_weight)
    }
1820
1821    /// Get all hyperparameters as a struct.
1822    pub fn get_params(&self) -> NGBoostParams {
1823        NGBoostParams {
1824            n_estimators: self.model.n_estimators,
1825            learning_rate: self.model.learning_rate,
1826            natural_gradient: self.model.natural_gradient,
1827            minibatch_frac: self.model.minibatch_frac,
1828            col_sample: self.model.col_sample,
1829            verbose: self.model.verbose,
1830            verbose_eval: self.model.verbose_eval,
1831            tol: self.model.tol,
1832            early_stopping_rounds: self.model.early_stopping_rounds,
1833            validation_fraction: self.model.validation_fraction,
1834            random_state: self.model.random_state(),
1835            lr_schedule: self.model.lr_schedule,
1836            tikhonov_reg: self.model.tikhonov_reg,
1837            line_search_method: self.model.line_search_method,
1838        }
1839    }
1840
1841    /// Set hyperparameters from a struct.
1842    /// Note: This only sets hyperparameters, not trained state.
1843    pub fn set_params(&mut self, params: NGBoostParams) {
1844        self.model.n_estimators = params.n_estimators;
1845        self.model.learning_rate = params.learning_rate;
1846        self.model.natural_gradient = params.natural_gradient;
1847        self.model.minibatch_frac = params.minibatch_frac;
1848        self.model.col_sample = params.col_sample;
1849        self.model.verbose = params.verbose;
1850        self.model.verbose_eval = params.verbose_eval;
1851        self.model.tol = params.tol;
1852        self.model.early_stopping_rounds = params.early_stopping_rounds;
1853        self.model.validation_fraction = params.validation_fraction;
1854        self.model.lr_schedule = params.lr_schedule;
1855        self.model.tikhonov_reg = params.tikhonov_reg;
1856        self.model.line_search_method = params.line_search_method;
1857        if let Some(seed) = params.random_state {
1858            self.model.set_random_state(seed);
1859        }
1860    }
1861}
1862
/// Multi-class classifier using NGBoost with Categorical distribution.
///
/// This is a generic struct parameterized by the number of classes K.
/// For binary classification, use `NGBClassifier` instead.
///
/// # Example
/// ```ignore
/// use ngboost_rs::ngboost::NGBMultiClassifier;
///
/// // Create a 5-class classifier
/// let mut model: NGBMultiClassifier<5> = NGBMultiClassifier::new(100, 0.1);
/// model.fit(&x_train, &y_train).unwrap();
/// let probs = model.predict_proba(&x_test);
/// ```
pub struct NGBMultiClassifier<const K: usize> {
    /// Underlying generic NGBoost engine specialized to a K-class
    /// categorical output with the log score and tree base learners.
    model: NGBoost<Categorical<K>, LogScore, DecisionTreeLearner>,
}
1880
impl<const K: usize> NGBMultiClassifier<K> {
    /// Create a K-class classifier with default options and the default
    /// decision-tree base learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Enhanced constructor exposing all tuning options of the underlying
    /// `NGBoost` model; see `NGBoost::with_options` for per-option semantics.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }

    /// Set a custom training loss monitor function
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Categorical<K>, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Set a custom validation loss monitor function
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Categorical<K>, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Fit the classifier. `y` holds class labels (presumably encoded as
    /// `0.0..K as f64` — confirm against `Categorical`'s fit implementation).
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with optional validation data for loss tracking / early stopping.
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Fits an NGBoost model to the data appending base models to the existing ones.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Partial fit with validation data support.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Predict class labels.
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Get predictions up to a specific iteration.
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Returns an iterator over staged predictions.
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Predict class probabilities.
    /// Returns a (N, K) array where N is the number of samples and K is the number of classes.
    pub fn predict_proba(&self, x: &Array2<f64>) -> Array2<f64> {
        let dist = self.model.pred_dist(x);
        dist.class_probs()
    }

    /// Get class probabilities up to a specific iteration.
    pub fn predict_proba_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        let dist = self.model.pred_dist_at(x, max_iter);
        dist.class_probs()
    }

    /// Returns an iterator over staged probability predictions.
    pub fn staged_predict_proba<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array2<f64>> + 'a {
        // One (N, K) probability matrix per boosting stage, computed lazily.
        (1..=self.model.base_models.len()).map(move |i| self.predict_proba_at(x, i))
    }

    /// Get the predicted distribution.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Categorical<K> {
        self.model.pred_dist(x)
    }

    /// Get the predicted distribution up to a specific iteration.
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Categorical<K> {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Returns an iterator over staged distribution predictions.
    pub fn staged_pred_dist<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Categorical<K>> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Get the predicted distribution parameters.
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Compute the average score (loss) on the given data.
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }

    /// Get the number of estimators (boosting iterations).
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Get the learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Get whether natural gradient is used.
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Get the minibatch fraction.
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Get the column sampling fraction.
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Get the best validation iteration.
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Get early stopping rounds.
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Get validation fraction.
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Get number of features the model was trained on.
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Compute feature importances per distribution parameter.
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Compute aggregated feature importances across all distribution parameters.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }

    /// Get the training history (losses at each iteration).
    pub fn evals_result(&self) -> &EvalsResult {
        self.model.evals_result()
    }

    /// Set the random seed for reproducibility.
    pub fn set_random_state(&mut self, seed: u64) {
        self.model.set_random_state(seed);
    }

    /// Get the current random state (seed), if set.
    pub fn random_state(&self) -> Option<u64> {
        self.model.random_state()
    }

    /// Fit with sample weights.
    pub fn fit_with_weights(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, None, None, sample_weight, None)
    }

    /// Fit with sample weights and validation data.
    pub fn fit_with_weights_and_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, sample_weight, val_sample_weight)
    }

    /// Get all hyperparameters as a struct.
    pub fn get_params(&self) -> NGBoostParams {
        NGBoostParams {
            n_estimators: self.model.n_estimators,
            learning_rate: self.model.learning_rate,
            natural_gradient: self.model.natural_gradient,
            minibatch_frac: self.model.minibatch_frac,
            col_sample: self.model.col_sample,
            verbose: self.model.verbose,
            verbose_eval: self.model.verbose_eval,
            tol: self.model.tol,
            early_stopping_rounds: self.model.early_stopping_rounds,
            validation_fraction: self.model.validation_fraction,
            random_state: self.model.random_state(),
            lr_schedule: self.model.lr_schedule,
            tikhonov_reg: self.model.tikhonov_reg,
            line_search_method: self.model.line_search_method,
        }
    }

    /// Set hyperparameters from a struct.
    pub fn set_params(&mut self, params: NGBoostParams) {
        self.model.n_estimators = params.n_estimators;
        self.model.learning_rate = params.learning_rate;
        self.model.natural_gradient = params.natural_gradient;
        self.model.minibatch_frac = params.minibatch_frac;
        self.model.col_sample = params.col_sample;
        self.model.verbose = params.verbose;
        self.model.verbose_eval = params.verbose_eval;
        self.model.tol = params.tol;
        self.model.early_stopping_rounds = params.early_stopping_rounds;
        self.model.validation_fraction = params.validation_fraction;
        self.model.lr_schedule = params.lr_schedule;
        self.model.tikhonov_reg = params.tikhonov_reg;
        self.model.line_search_method = params.line_search_method;
        // NOTE(review): a `None` random_state leaves any previously set seed
        // in place — confirm this asymmetry with get_params is intended.
        if let Some(seed) = params.random_state {
            self.model.set_random_state(seed);
        }
    }

    /// Get the number of classes.
    pub fn n_classes(&self) -> usize {
        K
    }
}
2174
/// Type alias for 3-class classification.
pub type NGBMultiClassifier3 = NGBMultiClassifier<3>;

/// Type alias for 4-class classification.
pub type NGBMultiClassifier4 = NGBMultiClassifier<4>;

/// Type alias for 5-class classification.
pub type NGBMultiClassifier5 = NGBMultiClassifier<5>;

/// Type alias for 10-class classification (e.g., digit recognition).
/// Add further aliases as needed; `NGBMultiClassifier<K>` works for any K.
pub type NGBMultiClassifier10 = NGBMultiClassifier<10>;
2186
2187impl NGBClassifier {
    /// Create a binary (Bernoulli) classifier with default options and the
    /// default decision-tree base learner.
    pub fn new(n_estimators: u32, learning_rate: f64) -> Self {
        Self {
            model: NGBoost::new(n_estimators, learning_rate, default_tree_learner()),
        }
    }

    /// Enhanced constructor exposing all tuning options of the underlying
    /// `NGBoost` model; see `NGBoost::with_options` for per-option semantics.
    pub fn with_options(
        n_estimators: u32,
        learning_rate: f64,
        natural_gradient: bool,
        minibatch_frac: f64,
        col_sample: f64,
        verbose: bool,
        verbose_eval: f64,
        tol: f64,
        early_stopping_rounds: Option<u32>,
        validation_fraction: f64,
        adaptive_learning_rate: bool,
    ) -> Self {
        Self {
            model: NGBoost::with_options(
                n_estimators,
                learning_rate,
                default_tree_learner(),
                natural_gradient,
                minibatch_frac,
                col_sample,
                verbose,
                verbose_eval,
                tol,
                early_stopping_rounds,
                validation_fraction,
                adaptive_learning_rate,
            ),
        }
    }
2224
    /// Set a custom training loss monitor function
    ///
    /// The monitor receives the predicted `Bernoulli` distribution, the targets,
    /// and optional sample weights, and returns the loss value to record.
    pub fn set_train_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_train_loss_monitor(Box::new(monitor));
    }

    /// Set a custom validation loss monitor function
    ///
    /// Same signature as the training monitor, applied to validation data.
    pub fn set_val_loss_monitor<F>(&mut self, monitor: F)
    where
        F: Fn(&Bernoulli, &Array1<f64>, Option<&Array1<f64>>) -> f64 + Send + Sync + 'static,
    {
        self.model.set_val_loss_monitor(Box::new(monitor));
    }

    /// Fit the binary classifier (presumably `y` contains 0/1 labels —
    /// confirm against `Bernoulli`'s fit implementation).
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.fit(x, y)
    }

    /// Fit with optional validation data for loss tracking / early stopping.
    pub fn fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, None, None)
    }
2255
    /// Fits an NGBoost model to the data appending base models to the existing ones.
    ///
    /// NOTE: This method is similar to Python's partial_fit. The first call will be the most
    /// significant and later calls will retune the model to newer data.
    ///
    /// Unlike `fit()`, this method does NOT reset the model state, allowing incremental learning.
    pub fn partial_fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<(), &'static str> {
        self.model.partial_fit(x, y)
    }

    /// Partial fit with validation data support.
    pub fn partial_fit_with_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        // The trailing `None, None` arguments are unexposed sample weights.
        self.model
            .partial_fit_with_validation(x, y, x_val, y_val, None, None)
    }

    /// Predict labels for each row of `x` (delegates to `NGBoost::predict`).
    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
        self.model.predict(x)
    }

    /// Get predictions up to a specific iteration
    pub fn predict_at(&self, x: &Array2<f64>, max_iter: usize) -> Array1<f64> {
        self.model.predict_at(x, max_iter)
    }

    /// Returns an iterator over staged predictions
    pub fn staged_predict<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array1<f64>> + 'a {
        self.model.staged_predict(x)
    }

    /// Predict class probabilities for each row of `x`
    /// (an (N, 2) matrix for the Bernoulli case — TODO confirm `class_probs` layout).
    pub fn predict_proba(&self, x: &Array2<f64>) -> Array2<f64> {
        let dist = self.model.pred_dist(x);
        dist.class_probs()
    }

    /// Get class probabilities up to a specific iteration
    pub fn predict_proba_at(&self, x: &Array2<f64>, max_iter: usize) -> Array2<f64> {
        let dist = self.model.pred_dist_at(x, max_iter);
        dist.class_probs()
    }

    /// Returns an iterator over staged probability predictions
    pub fn staged_predict_proba<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Array2<f64>> + 'a {
        // One probability matrix per boosting stage, computed lazily.
        (1..=self.model.base_models.len()).map(move |i| self.predict_proba_at(x, i))
    }
2313
    /// Predict the full `Bernoulli` distribution for the rows of `x`.
    pub fn pred_dist(&self, x: &Array2<f64>) -> Bernoulli {
        self.model.pred_dist(x)
    }

    /// Get the predicted distribution up to a specific iteration
    pub fn pred_dist_at(&self, x: &Array2<f64>, max_iter: usize) -> Bernoulli {
        self.model.pred_dist_at(x, max_iter)
    }

    /// Returns an iterator over staged distribution predictions
    pub fn staged_pred_dist<'a>(
        &'a self,
        x: &'a Array2<f64>,
    ) -> impl Iterator<Item = Bernoulli> + 'a {
        self.model.staged_pred_dist(x)
    }

    /// Get the predicted distribution parameters
    pub fn pred_param(&self, x: &Array2<f64>) -> Array2<f64> {
        self.model.pred_param(x)
    }

    /// Compute the average score (loss) on the given data
    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> f64 {
        self.model.score(x, y)
    }
2340
    /// Get the number of estimators (boosting iterations)
    pub fn n_estimators(&self) -> u32 {
        self.model.n_estimators
    }

    /// Get the learning rate
    pub fn learning_rate(&self) -> f64 {
        self.model.learning_rate
    }

    /// Get whether natural gradient is used
    pub fn natural_gradient(&self) -> bool {
        self.model.natural_gradient
    }

    /// Get the minibatch fraction
    pub fn minibatch_frac(&self) -> f64 {
        self.model.minibatch_frac
    }

    /// Get the column sampling fraction
    pub fn col_sample(&self) -> f64 {
        self.model.col_sample
    }

    /// Get the best validation iteration
    /// (`None` presumably means no validation-based best iteration was recorded).
    pub fn best_val_loss_itr(&self) -> Option<usize> {
        self.model.best_val_loss_itr
    }

    /// Get early stopping rounds
    pub fn early_stopping_rounds(&self) -> Option<u32> {
        self.model.early_stopping_rounds
    }

    /// Get validation fraction
    pub fn validation_fraction(&self) -> f64 {
        self.model.validation_fraction
    }

    /// Get number of features the model was trained on
    pub fn n_features(&self) -> Option<usize> {
        self.model.n_features()
    }

    /// Compute feature importances per distribution parameter.
    /// Returns a 2D array of shape (n_params, n_features).
    pub fn feature_importances(&self) -> Option<Array2<f64>> {
        self.model.feature_importances()
    }

    /// Compute aggregated feature importances across all distribution parameters.
    /// Returns a 1D array of length n_features with normalized importances.
    pub fn feature_importances_aggregated(&self) -> Option<Array1<f64>> {
        self.model.feature_importances_aggregated()
    }
2397
    /// Save model to file using bincode serialization
    ///
    /// # Errors
    /// Returns an error if model serialization, bincode encoding, or the file
    /// write fails.
    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        let serialized = self.model.serialize()?;
        let encoded = bincode::serialize(&serialized)?;
        std::fs::write(path, encoded)?;
        Ok(())
    }

    /// Load model from file using bincode deserialization
    ///
    /// The file must have been produced by [`save_model`](Self::save_model);
    /// the base learner is re-created via `default_tree_learner()`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read, decoded, or deserialized.
    pub fn load_model(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let encoded = std::fs::read(path)?;
        let serialized: SerializedNGBoost = bincode::deserialize(&encoded)?;
        let model = NGBoost::<Bernoulli, LogScore, DecisionTreeLearner>::deserialize(
            serialized,
            default_tree_learner(),
        )?;
        Ok(Self { model })
    }

    /// Calibrate uncertainty estimates using validation data.
    /// This improves the quality of probabilistic predictions for classification.
    ///
    /// Note: For classification, this adjusts the temperature scaling of the logits.
    pub fn calibrate_uncertainty(
        &mut self,
        _x_val: &Array2<f64>,
        _y_val: &Array1<f64>,
    ) -> Result<(), &'static str> {
        // For classification, temperature scaling could be applied here
        // Currently a no-op placeholder - classification calibration is more complex
        // than regression and typically uses Platt scaling or temperature scaling
        Ok(())
    }
2431
    /// Get the training history (losses at each iteration).
    pub fn evals_result(&self) -> &EvalsResult {
        self.model.evals_result()
    }

    /// Set the random seed for reproducibility.
    pub fn set_random_state(&mut self, seed: u64) {
        self.model.set_random_state(seed);
    }

    /// Get the current random state (seed), if set.
    pub fn random_state(&self) -> Option<u64> {
        self.model.random_state()
    }

    /// Fit with sample weights.
    ///
    /// `sample_weight`, when provided, is forwarded to the inner fit; presumably
    /// it must have one entry per row of `x` — confirm enforcement in
    /// `NGBoost::fit_with_validation`.
    pub fn fit_with_weights(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, None, None, sample_weight, None)
    }

    /// Fit with sample weights and validation data.
    pub fn fit_with_weights_and_validation(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<f64>,
        x_val: Option<&Array2<f64>>,
        y_val: Option<&Array1<f64>>,
        sample_weight: Option<&Array1<f64>>,
        val_sample_weight: Option<&Array1<f64>>,
    ) -> Result<(), &'static str> {
        self.model
            .fit_with_validation(x, y, x_val, y_val, sample_weight, val_sample_weight)
    }
2471
2472    /// Get all hyperparameters as a struct.
2473    pub fn get_params(&self) -> NGBoostParams {
2474        NGBoostParams {
2475            n_estimators: self.model.n_estimators,
2476            learning_rate: self.model.learning_rate,
2477            natural_gradient: self.model.natural_gradient,
2478            minibatch_frac: self.model.minibatch_frac,
2479            col_sample: self.model.col_sample,
2480            verbose: self.model.verbose,
2481            verbose_eval: self.model.verbose_eval,
2482            tol: self.model.tol,
2483            early_stopping_rounds: self.model.early_stopping_rounds,
2484            validation_fraction: self.model.validation_fraction,
2485            random_state: self.model.random_state(),
2486            lr_schedule: self.model.lr_schedule,
2487            tikhonov_reg: self.model.tikhonov_reg,
2488            line_search_method: self.model.line_search_method,
2489        }
2490    }
2491
2492    /// Set hyperparameters from a struct.
2493    /// Note: This only sets hyperparameters, not trained state.
2494    pub fn set_params(&mut self, params: NGBoostParams) {
2495        self.model.n_estimators = params.n_estimators;
2496        self.model.learning_rate = params.learning_rate;
2497        self.model.natural_gradient = params.natural_gradient;
2498        self.model.minibatch_frac = params.minibatch_frac;
2499        self.model.col_sample = params.col_sample;
2500        self.model.verbose = params.verbose;
2501        self.model.verbose_eval = params.verbose_eval;
2502        self.model.tol = params.tol;
2503        self.model.early_stopping_rounds = params.early_stopping_rounds;
2504        self.model.validation_fraction = params.validation_fraction;
2505        self.model.lr_schedule = params.lr_schedule;
2506        self.model.tikhonov_reg = params.tikhonov_reg;
2507        self.model.line_search_method = params.line_search_method;
2508        if let Some(seed) = params.random_state {
2509            self.model.set_random_state(seed);
2510        }
2511    }
2512}