#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use linfa::{Float, ParamGuard};

use crate::error::ElasticNetError;

use super::Result;

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// A verified hyper-parameter set ready for the estimation of an ElasticNet regression model
///
/// See [`ElasticNetParams`](crate::ElasticNetParams) for more information.
#[derive(Clone, Debug, PartialEq)]
pub struct ElasticNetValidParams<F> {
    /// Overall regularization strength; validated to be non-negative
    penalty: F,
    /// Share of the penalty assigned to the L1 term; validated to lie in `[0, 1]`
    l1_ratio: F,
    /// Whether an intercept term is fitted alongside the coefficients
    with_intercept: bool,
    /// Upper bound on the number of optimizer iterations
    max_iterations: u32,
    /// Convergence threshold on the absolute change of the parameters
    tolerance: F,
}

impl<F: Float> ElasticNetValidParams<F> {
    /// The overall regularization strength
    pub fn penalty(&self) -> F {
        self.penalty
    }

    /// The share of the penalty assigned to the L1 regularization term
    pub fn l1_ratio(&self) -> F {
        self.l1_ratio
    }

    /// Whether an intercept is fitted
    pub fn with_intercept(&self) -> bool {
        self.with_intercept
    }

    /// The iteration limit of the optimization routine
    pub fn max_iterations(&self) -> u32 {
        self.max_iterations
    }

    /// The convergence tolerance on the absolute parameter change
    pub fn tolerance(&self) -> F {
        self.tolerance
    }
}

/// A hyper-parameter set during construction
///
/// Configures and minimizes the following objective function:
/// ```ignore
/// 1 / (2 * n_samples) * ||y - Xw||^2_2
///     + penalty * l1_ratio * ||w||_1
///     + 0.5 * penalty * (1 - l1_ratio) * ||w||^2_2
/// ```
///
/// The parameter set can be verified into a
/// [`ElasticNetValidParams`](crate::hyperparams::ElasticNetValidParams) by calling
/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
/// [Fit::fit](linfa::traits::Fit::fit) which implicitly verifies the parameter set prior to the
/// model estimation and forwards any error.
///
/// # Parameters
/// | Name | Default | Purpose | Range |
/// | :--- | :--- | :---| :--- |
/// | [penalty](Self::penalty) | `1.0` | Overall parameter penalty | `[0, inf)` |
/// | [l1_ratio](Self::l1_ratio) | `0.5` | Distribution of penalty to L1 and L2 regularizations | `[0.0, 1.0]` |
/// | [with_intercept](Self::with_intercept) | `true` | Enable intercept | `false`, `true` |
/// | [tolerance](Self::tolerance) | `1e-4` | Absolute change of any of the parameters | `(0, inf)` |
/// | [max_iterations](Self::max_iterations) | `1000` | Maximum number of iterations | `[1, inf)` |
///
/// # Errors
///
/// The following errors can come from invalid hyper-parameters:
///
/// Returns [`InvalidPenalty`](ElasticNetError::InvalidPenalty) if the penalty is negative.
///
/// Returns [`InvalidL1Ratio`](ElasticNetError::InvalidL1Ratio) if the L1 ratio is not in the unit
/// range `[0.0, 1.0]`.
///
/// Returns [`InvalidTolerance`](ElasticNetError::InvalidTolerance) if the tolerance is negative.
///
/// # Example
///
/// ```rust
/// use linfa_elasticnet::{ElasticNetParams, ElasticNetError};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let ds = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![3.0, 2.0]);
///
/// // create a new parameter set with penalty equals `1e-5`
/// let unchecked_params = ElasticNetParams::new()
///     .penalty(1e-5);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // Regenerate model with the verified parameters, this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit(&ds)?;
/// # Ok::<(), ElasticNetError>(())
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct ElasticNetParams<F>(ElasticNetValidParams<F>);

impl<F: Float> Default for ElasticNetParams<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Builder methods to configure an Elastic Net model before fitting
impl<F: Float> ElasticNetParams<F> {
    /// Create default elastic net hyper parameters
    ///
    /// Defaults: `penalty = 1.0`, `l1_ratio = 0.5`, intercept fitting
    /// enabled, `max_iterations = 1000` and `tolerance = 1e-4`. To disable
    /// fitting an intercept, call `.with_intercept(false)` before calling
    /// `.fit()`.
    pub fn new() -> ElasticNetParams<F> {
        let defaults = ElasticNetValidParams {
            penalty: F::one(),
            l1_ratio: F::cast(0.5),
            with_intercept: true,
            max_iterations: 1000,
            tolerance: F::cast(1e-4),
        };
        Self(defaults)
    }

    /// Set the overall parameter penalty of the elastic net.
    /// Use [`l1_ratio`](Self::l1_ratio) to configure how the penalty is
    /// distributed between the L1 and L2 regularization terms.
    pub fn penalty(self, penalty: F) -> Self {
        Self(ElasticNetValidParams { penalty, ..self.0 })
    }

    /// Set the l1_ratio parameter of the elastic net, which controls how the
    /// parameter penalty is distributed between L1 and L2 regularization.
    /// A value of `1.0` is equivalent to a "Lasso" penalization, while `0.0`
    /// is equivalent to a "Ridge" penalization.
    ///
    /// Defaults to `0.5` if not set; must lie between `0.0` and `1.0`.
    pub fn l1_ratio(self, l1_ratio: F) -> Self {
        Self(ElasticNetValidParams { l1_ratio, ..self.0 })
    }

    /// Configure whether the elastic net model fits an intercept term.
    /// Defaults to `true` if not set.
    pub fn with_intercept(self, with_intercept: bool) -> Self {
        Self(ElasticNetValidParams { with_intercept, ..self.0 })
    }

    /// Set the tolerance: the minimum absolute change in any of the model
    /// parameters required for the parameter optimization to continue.
    ///
    /// Defaults to `1e-4` if not set.
    pub fn tolerance(self, tolerance: F) -> Self {
        Self(ElasticNetValidParams { tolerance, ..self.0 })
    }

    /// Set the maximum number of iterations for the optimization routine.
    ///
    /// Defaults to `1000` if not set.
    pub fn max_iterations(self, max_iterations: u32) -> Self {
        Self(ElasticNetValidParams { max_iterations, ..self.0 })
    }
}

impl<F: Float> ParamGuard for ElasticNetParams<F> {
    type Checked = ElasticNetValidParams<F>;
    type Error = ElasticNetError;

    /// Validate the hyper parameters without consuming the set
    ///
    /// Checks, in order: `penalty >= 0`, `l1_ratio` in `[0, 1]`, and
    /// `tolerance >= 0`; the first violation is reported as an error.
    fn check_ref(&self) -> Result<&Self::Checked> {
        let inner = &self.0;
        if inner.penalty.is_negative() {
            return Err(ElasticNetError::InvalidPenalty(
                inner.penalty.to_f32().unwrap(),
            ));
        }
        if !(F::zero()..=F::one()).contains(&inner.l1_ratio) {
            return Err(ElasticNetError::InvalidL1Ratio(
                inner.l1_ratio.to_f32().unwrap(),
            ));
        }
        if inner.tolerance.is_negative() {
            return Err(ElasticNetError::InvalidTolerance(
                inner.tolerance.to_f32().unwrap(),
            ));
        }
        Ok(inner)
    }

    /// Validate the hyper parameters and unwrap into the verified set
    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}