#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use linfa::Float;
use ndarray::{ArrayView1, CowArray, Ix1};

use super::Result;

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
/// Linear regression with both L1 and L2 regularization
///
/// Configures and minimizes the following objective function:
///     1 / (2 * n_samples) * ||y - Xw||^2_2
///     + penalty * l1_ratio * ||w||_1
///     + 0.5 * penalty * (1 - l1_ratio) * ||w||^2_2
///
pub struct ElasticNetParams<F> {
    pub penalty: F,
    pub l1_ratio: F,
    pub with_intercept: bool,
    pub max_iterations: u32,
    pub tolerance: F,
}

/// Configure and fit an Elastic Net model
impl<F: Float> ElasticNetParams<F> {
    /// Create default elastic net hyper parameters
    ///
    /// By default, an intercept will be fitted. To disable fitting an
    /// intercept, call `.with_intercept(false)` before calling `.fit()`.
    ///
    /// To additionally normalize the feature matrix before fitting, call
    /// `fit_intercept_and_normalize()` before calling `fit()`. The feature
    /// matrix will not be normalized by default.
    pub fn new() -> ElasticNetParams<F> {
        ElasticNetParams {
            penalty: F::one(),
            l1_ratio: F::from(0.5).unwrap(),
            with_intercept: true,
            max_iterations: 1000,
            tolerance: F::from(1e-4).unwrap(),
        }
    }

    /// Set the overall penalty parameter of the elastic net.
    /// Use `l1_ratio` to configure how the penalty is distributed between
    /// L1 and L2 regularization.
    pub fn penalty(mut self, penalty: F) -> Self {
        self.penalty = penalty;
        self
    }

    /// Set the `l1_ratio` parameter of the elastic net. Controls how the
    /// penalty is distributed between L1 and L2 regularization.
    /// Setting `l1_ratio` to 1.0 is equivalent to "Lasso" penalization;
    /// setting it to 0.0 is equivalent to "Ridge" penalization.
    ///
    /// Defaults to `0.5` if not set.
    ///
    /// `l1_ratio` must be between `0.0` and `1.0`; this method panics otherwise.
    pub fn l1_ratio(mut self, l1_ratio: F) -> Self {
        if l1_ratio < F::zero() || l1_ratio > F::one() {
            panic!("Invalid value for l1_ratio, needs to be between 0.0 and 1.0");
        }
        self.l1_ratio = l1_ratio;
        self
    }

    /// Configure the elastic net model to fit an intercept.
    /// Defaults to `true` if not set.
    pub fn with_intercept(mut self, with_intercept: bool) -> Self {
        self.with_intercept = with_intercept;
        self
    }

    /// Set the tolerance, i.e. the minimum absolute change in any of the
    /// model parameters needed for the parameter optimization to continue.
    ///
    /// Defaults to `1e-4` if not set.
    pub fn tolerance(mut self, tolerance: F) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the maximum number of iterations for the optimization routine.
    ///
    /// Defaults to `1000` if not set.
    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Compute the intercept as the mean of `y` and center `y` if an intercept
    /// should be used; otherwise use `0.0` as the intercept and leave `y` unchanged.
    pub fn compute_intercept<'a>(&self, y: ArrayView1<'a, F>) -> (F, CowArray<'a, F, Ix1>) {
        if self.with_intercept {
            // `mean()` returns `None` only for an empty array.
            let y_mean = y.mean().unwrap();
            // Allocates a new, centered copy of `y`.
            let y_centered = &y - y_mean;
            (y_mean, y_centered.into())
        } else {
            // Borrow `y` as-is; no copy is made.
            (F::zero(), y.into())
        }
    }

    /// Validate the hyper parameters
    ///
    /// This function is called in `Self::fit` and validates all hyper parameters:
    /// `penalty` and `tolerance` must be non-negative and `l1_ratio` must lie in `[0, 1]`.
    pub fn validate_params(&self) -> Result<()> {
        if self.penalty.is_negative() {
            let msg = format!("Penalty should be non-negative, but is {}", self.penalty);
            return Err(linfa::Error::Parameters(msg).into());
        }
        if self.tolerance.is_negative() {
            let msg = format!("Tolerance should be non-negative, but is {}", self.tolerance);
            return Err(linfa::Error::Parameters(msg).into());
        }
        if self.l1_ratio.is_negative() || self.l1_ratio > F::one() {
            let msg = format!("L1 ratio should be in range [0, 1], but is {}", self.l1_ratio);
            return Err(linfa::Error::Parameters(msg).into());
        }
        Ok(())
    }
}
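The objective in the `ElasticNetParams` doc comment can be checked numerically with a minimal, standalone sketch. It assumes plain `f64` slices (each row of `x` stored as a `Vec<f64>`) instead of the `ndarray`/`linfa` types used above, and the function name `elastic_net_objective` is hypothetical. It only evaluates the objective for a given coefficient vector `w`; it does not fit a model or reproduce this crate's solver.

```rust
/// Hypothetical helper (not part of linfa): evaluates the elastic net objective
///   1 / (2 * n_samples) * ||y - Xw||^2_2
///   + penalty * l1_ratio * ||w||_1
///   + 0.5 * penalty * (1 - l1_ratio) * ||w||^2_2
/// with each row of `x` stored as a `Vec<f64>` (an assumed layout).
fn elastic_net_objective(x: &[Vec<f64>], y: &[f64], w: &[f64], penalty: f64, l1_ratio: f64) -> f64 {
    let n_samples = y.len() as f64;

    // Squared L2 norm of the residual vector y - Xw.
    let rss: f64 = x
        .iter()
        .zip(y)
        .map(|(row, &yi)| {
            let prediction: f64 = row.iter().zip(w).map(|(xij, wj)| xij * wj).sum();
            (yi - prediction).powi(2)
        })
        .sum();

    // L1 norm and squared L2 norm of the coefficients.
    let l1_norm: f64 = w.iter().map(|wj| wj.abs()).sum();
    let l2_norm_sq: f64 = w.iter().map(|wj| wj * wj).sum();

    rss / (2.0 * n_samples)
        + penalty * l1_ratio * l1_norm
        + 0.5 * penalty * (1.0 - l1_ratio) * l2_norm_sq
}

fn main() {
    let x = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let y = vec![1.0, 2.0];
    let w = vec![1.0, 1.0];

    // Same overall penalty, three different splits between L1 and L2.
    for l1_ratio in [0.0, 0.5, 1.0] {
        let value = elastic_net_objective(&x, &y, &w, 1.0, l1_ratio);
        println!("l1_ratio = {l1_ratio}: objective = {value}");
    }
}
```

For this toy data the data-fit term is 0.25, and the objective comes out as 1.25, 1.75 and 2.25 for `l1_ratio` of 0.0, 0.5 and 1.0. With `l1_ratio = 1.0` the L2 term vanishes (Lasso), and with `l1_ratio = 0.0` the L1 term vanishes (Ridge), matching the documentation of the `l1_ratio` setter.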