Skip to main content

ndarray_glm/fit/
options.rs

1//! Fit-specific configuration and fit builder
2use super::Fit;
3use crate::{Array1, error::RegressionResult, glm::Glm, model::Model, num::Float};
4
5/// A builder struct for fit configuration
6pub struct FitConfig<'a, M, F>
7where
8    M: Glm,
9    F: Float,
10{
11    pub(crate) model: &'a Model<M, F>,
12    pub options: FitOptions<F>,
13}
14
15impl<'a, M, F> FitConfig<'a, M, F>
16where
17    M: Glm,
18    F: Float,
19{
20    pub fn fit(self) -> RegressionResult<Fit<'a, M, F>, F> {
21        M::regression(self.model, self.options)
22    }
23
24    /// Use a maximum number of iterations
25    pub fn max_iter(mut self, max_iter: usize) -> Self {
26        self.options.max_iter = max_iter;
27        self
28    }
29
30    /// Set the tolerance of iteration
31    pub fn tol(mut self, tol: F) -> Self {
32        self.options.tol = tol;
33        self
34    }
35
36    /// Set the L2 (ridge) regularization penalty weight.
37    ///
38    /// NOTE: The fit is sensitive to the scale of the data under L2 regularization. By default,
39    /// the data and parameters are internally standardized so that the contributions from features
40    /// with low variances relative to their offsets are not overly suppressed. The reported
41    /// coefficients are transformed back to the scale of the data, so that they can be applied
42    /// directly to the input data. This default is the recommended approach, and should be
43    /// invisible to the user.
44    ///
45    /// To disable this internal regularization, use
46    /// [`crate::model::ModelBuilderData::no_standardize`].
47    pub fn l2_reg(mut self, l2: F) -> Self {
48        self.options.l2 = l2;
49        self
50    }
51
52    /// Set the L1 (lasso) regularization penalty weight.
53    ///
54    /// L1 regularization incurs the same scale sensitivity as L2 regularization.
55    pub fn l1_reg(mut self, l1: F) -> Self {
56        self.options.l1 = l1;
57        self
58    }
59}
60
61/// Specifies the fitting options
62#[derive(Clone)]
63pub struct FitOptions<F>
64where
65    F: Float,
66{
67    /// The maximum number of IRLS iterations
68    pub max_iter: usize,
69    /// The relative tolerance of the likelihood
70    pub tol: F,
71    /// The regularization of the fit
72    pub l2: F,
73    pub l1: F,
74    /// An initial guess. A sensible default is selected if this is not provided.
75    pub init_guess: Option<Array1<F>>,
76}
77
78impl<F> Default for FitOptions<F>
79where
80    F: Float,
81{
82    fn default() -> Self {
83        Self {
84            max_iter: 128,
85            // This tolerance is rather small, but it is used in the context of a
86            // fraction of the total likelihood.
87            tol: F::epsilon(),
88            l2: F::zero(),
89            l1: F::zero(),
90            init_guess: None,
91        }
92    }
93}