1use ferrolearn_core::error::FerroError;
29use ferrolearn_core::introspection::HasCoefficients;
30use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2, Axis, ScalarOperand};
33use num_traits::{Float, FromPrimitive};
34
/// Configuration for an L1-regularised linear regression fitted by coordinate descent.
///
/// All fields are public; the `with_*` builder methods offer a chained way to set them.
#[derive(Clone, Debug)]
pub struct Lasso<F> {
    /// L1 penalty strength, used as the soft-threshold level during fitting.
    /// Fitting rejects negative values.
    pub alpha: F,
    /// Upper bound on the number of full passes over the features.
    pub max_iter: usize,
    /// Fitting stops early once the largest coefficient change in a pass drops below this.
    pub tol: F,
    /// When `true`, the data are centred before fitting and an intercept is recovered.
    pub fit_intercept: bool,
}
56
57impl<F: Float> Lasso<F> {
58 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 alpha: F::one(),
66 max_iter: 1000,
67 tol: F::from(1e-4).unwrap(),
68 fit_intercept: true,
69 }
70 }
71
72 #[must_use]
74 pub fn with_alpha(mut self, alpha: F) -> Self {
75 self.alpha = alpha;
76 self
77 }
78
79 #[must_use]
81 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
82 self.max_iter = max_iter;
83 self
84 }
85
86 #[must_use]
88 pub fn with_tol(mut self, tol: F) -> Self {
89 self.tol = tol;
90 self
91 }
92
93 #[must_use]
95 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
96 self.fit_intercept = fit_intercept;
97 self
98 }
99}
100
101impl<F: Float> Default for Lasso<F> {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107#[derive(Debug, Clone)]
112pub struct FittedLasso<F> {
113 coefficients: Array1<F>,
115 intercept: F,
117}
118
119fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
123 if x > threshold {
124 x - threshold
125 } else if x < -threshold {
126 x + threshold
127 } else {
128 F::zero()
129 }
130}
131
132impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
133 for Lasso<F>
134{
135 type Fitted = FittedLasso<F>;
136 type Error = FerroError;
137
138 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLasso<F>, FerroError> {
148 let (n_samples, n_features) = x.dim();
149
150 if n_samples != y.len() {
151 return Err(FerroError::ShapeMismatch {
152 expected: vec![n_samples],
153 actual: vec![y.len()],
154 context: "y length must match number of samples in X".into(),
155 });
156 }
157
158 if self.alpha < F::zero() {
159 return Err(FerroError::InvalidParameter {
160 name: "alpha".into(),
161 reason: "must be non-negative".into(),
162 });
163 }
164
165 if n_samples == 0 {
166 return Err(FerroError::InsufficientSamples {
167 required: 1,
168 actual: 0,
169 context: "Lasso requires at least one sample".into(),
170 });
171 }
172
173 let n_f = F::from(n_samples).unwrap();
174
175 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
177 let x_mean = x
178 .mean_axis(Axis(0))
179 .ok_or_else(|| FerroError::NumericalInstability {
180 message: "failed to compute column means".into(),
181 })?;
182 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
183 message: "failed to compute target mean".into(),
184 })?;
185
186 let x_c = x - &x_mean;
187 let y_c = y - y_mean;
188 (x_c, y_c, Some(x_mean), Some(y_mean))
189 } else {
190 (x.clone(), y.clone(), None, None)
191 };
192
193 let col_norms: Vec<F> = (0..n_features)
195 .map(|j| {
196 let col = x_work.column(j);
197 col.dot(&col) / n_f
198 })
199 .collect();
200
201 let mut w = Array1::<F>::zeros(n_features);
203 let mut residual = y_work;
204
205 for _iter in 0..self.max_iter {
206 let mut max_change = F::zero();
207
208 for j in 0..n_features {
209 let col_j = x_work.column(j);
210
211 let w_old = w[j];
213 if w_old != F::zero() {
214 for i in 0..n_samples {
215 residual[i] = residual[i] + col_j[i] * w_old;
216 }
217 }
218
219 let rho = col_j.dot(&residual) / n_f;
221
222 let w_new = if col_norms[j] > F::zero() {
224 soft_threshold(rho, self.alpha) / col_norms[j]
225 } else {
226 F::zero()
227 };
228
229 if w_new != F::zero() {
231 for i in 0..n_samples {
232 residual[i] = residual[i] - col_j[i] * w_new;
233 }
234 }
235
236 let change = (w_new - w_old).abs();
237 if change > max_change {
238 max_change = change;
239 }
240
241 w[j] = w_new;
242 }
243
244 if max_change < self.tol {
246 let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
247 *ym - xm.dot(&w)
248 } else {
249 F::zero()
250 };
251
252 return Ok(FittedLasso {
253 coefficients: w,
254 intercept,
255 });
256 }
257 }
258
259 let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
261 *ym - xm.dot(&w)
262 } else {
263 F::zero()
264 };
265
266 Ok(FittedLasso {
267 coefficients: w,
268 intercept,
269 })
270 }
271}
272
273impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedLasso<F> {
274 type Output = Array1<F>;
275 type Error = FerroError;
276
277 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
286 let n_features = x.ncols();
287 if n_features != self.coefficients.len() {
288 return Err(FerroError::ShapeMismatch {
289 expected: vec![self.coefficients.len()],
290 actual: vec![n_features],
291 context: "number of features must match fitted model".into(),
292 });
293 }
294
295 let preds = x.dot(&self.coefficients) + self.intercept;
296 Ok(preds)
297 }
298}
299
300impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedLasso<F> {
301 fn coefficients(&self) -> &Array1<F> {
302 &self.coefficients
303 }
304
305 fn intercept(&self) -> F {
306 self.intercept
307 }
308}
309
310impl<F> PipelineEstimator<F> for Lasso<F>
312where
313 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
314{
315 fn fit_pipeline(
316 &self,
317 x: &Array2<F>,
318 y: &Array1<F>,
319 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
320 let fitted = self.fit(x, y)?;
321 Ok(Box::new(fitted))
322 }
323}
324
325impl<F> FittedPipelineEstimator<F> for FittedLasso<F>
326where
327 F: Float + ScalarOperand + Send + Sync + 'static,
328{
329 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
330 self.predict(x)
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Builds a single-feature design matrix with one row per value.
    fn column_matrix(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).unwrap()
    }

    #[test]
    fn test_soft_threshold() {
        // (input, expected) pairs, all with a threshold of 1.
        let cases = [
            (5.0_f64, 4.0_f64),
            (-5.0, -4.0),
            (0.5, 0.0),
            (-0.5, 0.0),
            (0.0, 0.0),
        ];
        for (input, expected) in cases {
            assert_relative_eq!(soft_threshold(input, 1.0), expected);
        }
    }

    #[test]
    fn test_lasso_zero_alpha() {
        // y = 2x + 1; with no penalty the fit should recover slope and intercept.
        let x = column_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = Lasso::<f64>::new().with_alpha(0.0).fit(&x, &y).unwrap();

        let slope = fitted.coefficients()[0];
        assert_relative_eq!(slope, 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_lasso_sparsity() {
        // First column runs 1..=10, the other two columns are all zeros.
        let x = Array2::from_shape_fn((10, 3), |(row, col)| {
            if col == 0 {
                (row + 1) as f64
            } else {
                0.0
            }
        });
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = Lasso::<f64>::new().with_alpha(5.0).fit(&x, &y).unwrap();
        let weights = fitted.coefficients();

        assert_relative_eq!(weights[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(weights[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lasso_shrinks_coefficients() {
        let x = column_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let weak_penalty = Lasso::<f64>::new().with_alpha(0.01).fit(&x, &y).unwrap();
        let strong_penalty = Lasso::<f64>::new().with_alpha(1.0).fit(&x, &y).unwrap();

        let weak = weak_penalty.coefficients()[0].abs();
        let strong = strong_penalty.coefficients()[0].abs();
        assert!(strong <= weak);
    }

    #[test]
    fn test_lasso_no_intercept() {
        let x = column_matrix(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Lasso::<f64>::new()
            .with_alpha(0.0)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_lasso_negative_alpha() {
        let x = column_matrix(&[1.0, 2.0, 3.0]);
        let y = array![1.0, 2.0, 3.0];

        let outcome = Lasso::<f64>::new().with_alpha(-1.0).fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_lasso_shape_mismatch() {
        // Three rows in x but only two targets.
        let x = column_matrix(&[1.0, 2.0, 3.0]);
        let y = array![1.0, 2.0];

        let outcome = Lasso::<f64>::new().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_lasso_predict() {
        let x = column_matrix(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Lasso::<f64>::new().with_alpha(0.01).fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_lasso_pipeline_integration() {
        let x = column_matrix(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![3.0, 5.0, 7.0, 9.0];

        let estimator = Lasso::<f64>::new().with_alpha(0.01);
        let fitted = estimator.fit_pipeline(&x, &y).unwrap();
        let predictions = fitted.predict_pipeline(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_lasso_has_coefficients() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let fitted = Lasso::<f64>::new().with_alpha(0.1).fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }
}