ndarray_glm/fit/options.rs
1//! Fit-specific configuration and fit builder
2use super::Fit;
3use crate::{Array1, error::RegressionResult, glm::Glm, model::Model, num::Float};
4
5/// A builder struct for fit configuration
6pub struct FitConfig<'a, M, F>
7where
8 M: Glm,
9 F: Float,
10{
11 pub(crate) model: &'a Model<M, F>,
12 pub options: FitOptions<F>,
13}
14
15impl<'a, M, F> FitConfig<'a, M, F>
16where
17 M: Glm,
18 F: Float,
19{
20 pub fn fit(self) -> RegressionResult<Fit<'a, M, F>, F> {
21 M::regression(self.model, self.options)
22 }
23
24 /// Use a maximum number of iterations
25 pub fn max_iter(mut self, max_iter: usize) -> Self {
26 self.options.max_iter = max_iter;
27 self
28 }
29
30 /// Set the tolerance of iteration
31 pub fn tol(mut self, tol: F) -> Self {
32 self.options.tol = tol;
33 self
34 }
35
36 /// Set the L2 (ridge) regularization penalty weight.
37 ///
38 /// NOTE: The fit is sensitive to the scale of the data under L2 regularization. By default,
39 /// the data and parameters are internally standardized so that the contributions from features
40 /// with low variances relative to their offsets are not overly suppressed. The reported
41 /// coefficients are transformed back to the scale of the data, so that they can be applied
42 /// directly to the input data. This default is the recommended approach, and should be
43 /// invisible to the user.
44 ///
45 /// To disable this internal regularization, use
46 /// [`crate::model::ModelBuilderData::no_standardize`].
47 pub fn l2_reg(mut self, l2: F) -> Self {
48 self.options.l2 = l2;
49 self
50 }
51
52 /// Set the L1 (lasso) regularization penalty weight.
53 ///
54 /// L1 regularization incurs the same scale sensitivity as L2 regularization.
55 pub fn l1_reg(mut self, l1: F) -> Self {
56 self.options.l1 = l1;
57 self
58 }
59}
60
61/// Specifies the fitting options
62#[derive(Clone)]
63pub struct FitOptions<F>
64where
65 F: Float,
66{
67 /// The maximum number of IRLS iterations
68 pub max_iter: usize,
69 /// The relative tolerance of the likelihood
70 pub tol: F,
71 /// The regularization of the fit
72 pub l2: F,
73 pub l1: F,
74 /// An initial guess. A sensible default is selected if this is not provided.
75 pub init_guess: Option<Array1<F>>,
76}
77
78impl<F> Default for FitOptions<F>
79where
80 F: Float,
81{
82 fn default() -> Self {
83 Self {
84 max_iter: 128,
85 // This tolerance is rather small, but it is used in the context of a
86 // fraction of the total likelihood.
87 tol: F::epsilon(),
88 l2: F::zero(),
89 l1: F::zero(),
90 init_guess: None,
91 }
92 }
93}