1use ferrolearn_core::error::FerroError;
40use ferrolearn_core::introspection::HasCoefficients;
41use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
42use ferrolearn_core::traits::{Fit, Predict};
43use ndarray::{Array1, Array2, Axis, ScalarOperand};
44use num_traits::{Float, FromPrimitive};
45
46#[derive(Debug, Clone)]
60pub struct Lars<F> {
61 pub n_nonzero_coefs: Option<usize>,
64 pub fit_intercept: bool,
66 _marker: core::marker::PhantomData<F>,
67}
68
69impl<F: Float> Lars<F> {
70 #[must_use]
74 pub fn new() -> Self {
75 Self {
76 n_nonzero_coefs: None,
77 fit_intercept: true,
78 _marker: core::marker::PhantomData,
79 }
80 }
81
82 #[must_use]
84 pub fn with_n_nonzero_coefs(mut self, n: usize) -> Self {
85 self.n_nonzero_coefs = Some(n);
86 self
87 }
88
89 #[must_use]
91 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
92 self.fit_intercept = fit_intercept;
93 self
94 }
95}
96
97impl<F: Float> Default for Lars<F> {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
/// A fitted `Lars` model: the learned coefficients and intercept.
#[derive(Debug, Clone)]
pub struct FittedLars<F> {
    /// One coefficient per feature column of the training `X`; features that
    /// never entered the active set keep a coefficient of zero.
    coefficients: Array1<F>,
    /// Intercept term; zero when the model was fitted with `fit_intercept = false`.
    intercept: F,
}
114
115#[derive(Debug, Clone)]
129pub struct LassoLars<F> {
130 pub alpha: F,
132 pub max_iter: usize,
134 pub fit_intercept: bool,
136}
137
138impl<F: Float> LassoLars<F> {
139 #[must_use]
143 pub fn new() -> Self {
144 Self {
145 alpha: F::one(),
146 max_iter: 500,
147 fit_intercept: true,
148 }
149 }
150
151 #[must_use]
153 pub fn with_alpha(mut self, alpha: F) -> Self {
154 self.alpha = alpha;
155 self
156 }
157
158 #[must_use]
160 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
161 self.max_iter = max_iter;
162 self
163 }
164
165 #[must_use]
167 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
168 self.fit_intercept = fit_intercept;
169 self
170 }
171}
172
173impl<F: Float> Default for LassoLars<F> {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
/// A fitted `LassoLars` model: the learned coefficients and intercept.
#[derive(Debug, Clone)]
pub struct FittedLassoLars<F> {
    /// One coefficient per feature column of the training `X`; features outside
    /// the final active set have a coefficient of zero.
    coefficients: Array1<F>,
    /// Intercept term; zero when the model was fitted with `fit_intercept = false`.
    intercept: F,
}
189
190fn ols_active<F: Float + FromPrimitive + 'static>(
198 x: &Array2<F>,
199 y: &Array1<F>,
200 active: &[usize],
201 n_features: usize,
202) -> Result<Array1<F>, FerroError> {
203 let n_samples = x.nrows();
204 let k = active.len();
205
206 let mut xa = Array2::<F>::zeros((n_samples, k));
208 for (col_idx, &j) in active.iter().enumerate() {
209 for i in 0..n_samples {
210 xa[[i, col_idx]] = x[[i, j]];
211 }
212 }
213
214 let xat = xa.t();
216 let xtx = xat.dot(&xa);
217 let xty = xat.dot(y);
218
219 let w_active = cholesky_solve(&xtx, &xty)
220 .or_else(|_| gaussian_solve(k, &xtx, &xty))?;
221
222 let mut w = Array1::<F>::zeros(n_features);
224 for (col_idx, &j) in active.iter().enumerate() {
225 w[j] = w_active[col_idx];
226 }
227 Ok(w)
228}
229
230fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
232 let n = a.nrows();
233 let mut l = Array2::<F>::zeros((n, n));
234
235 for i in 0..n {
236 for j in 0..=i {
237 let mut s = a[[i, j]];
238 for k in 0..j {
239 s = s - l[[i, k]] * l[[j, k]];
240 }
241 if i == j {
242 if s <= F::zero() {
243 return Err(FerroError::NumericalInstability {
244 message: "Cholesky: matrix not positive definite".into(),
245 });
246 }
247 l[[i, j]] = s.sqrt();
248 } else {
249 l[[i, j]] = s / l[[j, j]];
250 }
251 }
252 }
253
254 let mut z = Array1::<F>::zeros(n);
255 for i in 0..n {
256 let mut s = b[i];
257 for k in 0..i {
258 s = s - l[[i, k]] * z[k];
259 }
260 z[i] = s / l[[i, i]];
261 }
262
263 let mut x_sol = Array1::<F>::zeros(n);
264 for i in (0..n).rev() {
265 let mut s = z[i];
266 for k in (i + 1)..n {
267 s = s - l[[k, i]] * x_sol[k];
268 }
269 x_sol[i] = s / l[[i, i]];
270 }
271
272 Ok(x_sol)
273}
274
275fn gaussian_solve<F: Float>(
277 n: usize,
278 a: &Array2<F>,
279 b: &Array1<F>,
280) -> Result<Array1<F>, FerroError> {
281 let mut aug = Array2::<F>::zeros((n, n + 1));
282 for i in 0..n {
283 for j in 0..n {
284 aug[[i, j]] = a[[i, j]];
285 }
286 aug[[i, n]] = b[i];
287 }
288
289 for col in 0..n {
290 let mut max_val = aug[[col, col]].abs();
291 let mut max_row = col;
292 for row in (col + 1)..n {
293 let v = aug[[row, col]].abs();
294 if v > max_val {
295 max_val = v;
296 max_row = row;
297 }
298 }
299
300 if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
301 return Err(FerroError::NumericalInstability {
302 message: "singular matrix in Gaussian elimination".into(),
303 });
304 }
305
306 if max_row != col {
307 for j in 0..=n {
308 let tmp = aug[[col, j]];
309 aug[[col, j]] = aug[[max_row, j]];
310 aug[[max_row, j]] = tmp;
311 }
312 }
313
314 let pivot = aug[[col, col]];
315 for row in (col + 1)..n {
316 let factor = aug[[row, col]] / pivot;
317 for j in col..=n {
318 let above = aug[[col, j]];
319 aug[[row, j]] = aug[[row, j]] - factor * above;
320 }
321 }
322 }
323
324 let mut x_sol = Array1::<F>::zeros(n);
325 for i in (0..n).rev() {
326 let mut s = aug[[i, n]];
327 for j in (i + 1)..n {
328 s = s - aug[[i, j]] * x_sol[j];
329 }
330 if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
331 return Err(FerroError::NumericalInstability {
332 message: "near-zero pivot in back substitution".into(),
333 });
334 }
335 x_sol[i] = s / aug[[i, i]];
336 }
337
338 Ok(x_sol)
339}
340
341type CentredData<F> = (Array2<F>, Array1<F>, Option<Array1<F>>, Option<F>);
343
344fn center_data<F: Float + FromPrimitive + ScalarOperand + 'static>(
346 x: &Array2<F>,
347 y: &Array1<F>,
348 fit_intercept: bool,
349) -> Result<CentredData<F>, FerroError> {
350 if fit_intercept {
351 let x_mean = x
352 .mean_axis(Axis(0))
353 .ok_or_else(|| FerroError::NumericalInstability {
354 message: "failed to compute column means".into(),
355 })?;
356 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
357 message: "failed to compute target mean".into(),
358 })?;
359 let x_c = x - &x_mean;
360 let y_c = y - y_mean;
361 Ok((x_c, y_c, Some(x_mean), Some(y_mean)))
362 } else {
363 Ok((x.clone(), y.clone(), None, None))
364 }
365}
366
367fn compute_intercept<F: Float + 'static>(
369 x_mean: &Option<Array1<F>>,
370 y_mean: &Option<F>,
371 w: &Array1<F>,
372) -> F {
373 if let (Some(xm), Some(ym)) = (x_mean, y_mean) {
374 *ym - xm.dot(w)
375 } else {
376 F::zero()
377 }
378}
379
380fn validate_input<F: Float>(
382 x: &Array2<F>,
383 y: &Array1<F>,
384 name: &str,
385) -> Result<(usize, usize), FerroError> {
386 let (n_samples, n_features) = x.dim();
387
388 if n_samples != y.len() {
389 return Err(FerroError::ShapeMismatch {
390 expected: vec![n_samples],
391 actual: vec![y.len()],
392 context: "y length must match number of samples in X".into(),
393 });
394 }
395
396 if n_samples == 0 {
397 return Err(FerroError::InsufficientSamples {
398 required: 1,
399 actual: 0,
400 context: format!("{name} requires at least one sample"),
401 });
402 }
403
404 Ok((n_samples, n_features))
405}
406
407impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
412 for Lars<F>
413{
414 type Fitted = FittedLars<F>;
415 type Error = FerroError;
416
417 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLars<F>, FerroError> {
428 let (_n_samples, n_features) = validate_input(x, y, "Lars")?;
429
430 let max_active = self.n_nonzero_coefs.unwrap_or(n_features);
431 if max_active > n_features {
432 return Err(FerroError::InvalidParameter {
433 name: "n_nonzero_coefs".into(),
434 reason: format!(
435 "cannot exceed number of features ({n_features})"
436 ),
437 });
438 }
439
440 let (x_work, y_work, x_mean, y_mean) =
441 center_data(x, y, self.fit_intercept)?;
442
443 let mut active: Vec<usize> = Vec::with_capacity(max_active);
444 let mut in_active = vec![false; n_features];
445 let mut w = Array1::<F>::zeros(n_features);
446 let mut residual = y_work.clone();
447
448 for _step in 0..max_active {
449 let mut best_j = None;
451 let mut best_corr = F::zero();
452 for (j, &is_active) in in_active.iter().enumerate() {
453 if is_active {
454 continue;
455 }
456 let corr = x_work.column(j).dot(&residual).abs();
457 if corr > best_corr {
458 best_corr = corr;
459 best_j = Some(j);
460 }
461 }
462
463 let j = match best_j {
464 Some(j) => j,
465 None => break, };
467
468 active.push(j);
469 in_active[j] = true;
470
471 w = ols_active(&x_work, &y_work, &active, n_features)?;
473
474 residual = &y_work - x_work.dot(&w);
476 }
477
478 let intercept = compute_intercept(&x_mean, &y_mean, &w);
479
480 Ok(FittedLars {
481 coefficients: w,
482 intercept,
483 })
484 }
485}
486
487impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
492 for LassoLars<F>
493{
494 type Fitted = FittedLassoLars<F>;
495 type Error = FerroError;
496
497 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLassoLars<F>, FerroError> {
510 let (n_samples, n_features) = validate_input(x, y, "LassoLars")?;
511
512 if self.alpha < F::zero() {
513 return Err(FerroError::InvalidParameter {
514 name: "alpha".into(),
515 reason: "must be non-negative".into(),
516 });
517 }
518
519 let n_f = F::from(n_samples).unwrap();
520 let (x_work, y_work, x_mean, y_mean) =
521 center_data(x, y, self.fit_intercept)?;
522
523 let mut active: Vec<usize> = Vec::new();
524 let mut in_active = vec![false; n_features];
525 let mut w = Array1::<F>::zeros(n_features);
526 let mut residual = y_work.clone();
527
528 for _step in 0..self.max_iter {
529 let mut best_j = None;
531 let mut best_corr = F::zero();
532 for (j, &is_active) in in_active.iter().enumerate() {
533 if is_active {
534 continue;
535 }
536 let corr = x_work.column(j).dot(&residual).abs() / n_f;
537 if corr > best_corr {
538 best_corr = corr;
539 best_j = Some(j);
540 }
541 }
542
543 if best_corr <= self.alpha && !active.is_empty() {
545 break;
546 }
547
548 if let Some(j) = best_j {
550 active.push(j);
551 in_active[j] = true;
552 } else {
553 break;
554 }
555
556 let w_new = ols_active(&x_work, &y_work, &active, n_features)?;
558
559 let mut dropped = false;
561 for idx in (0..active.len()).rev() {
562 let feat = active[idx];
563 if w[feat] != F::zero()
565 && w_new[feat].signum() != w[feat].signum()
566 {
567 active.remove(idx);
568 in_active[feat] = false;
569 dropped = true;
570 }
571 }
572
573 if dropped && !active.is_empty() {
574 w = ols_active(&x_work, &y_work, &active, n_features)?;
576 } else {
577 w = w_new;
578 }
579
580 residual = &y_work - x_work.dot(&w);
582 }
583
584 let intercept = compute_intercept(&x_mean, &y_mean, &w);
585
586 Ok(FittedLassoLars {
587 coefficients: w,
588 intercept,
589 })
590 }
591}
592
593impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedLars<F> {
598 type Output = Array1<F>;
599 type Error = FerroError;
600
601 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
610 if x.ncols() != self.coefficients.len() {
611 return Err(FerroError::ShapeMismatch {
612 expected: vec![self.coefficients.len()],
613 actual: vec![x.ncols()],
614 context: "number of features must match fitted model".into(),
615 });
616 }
617 Ok(x.dot(&self.coefficients) + self.intercept)
618 }
619}
620
621impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedLars<F> {
622 fn coefficients(&self) -> &Array1<F> {
623 &self.coefficients
624 }
625
626 fn intercept(&self) -> F {
627 self.intercept
628 }
629}
630
631impl<F> PipelineEstimator<F> for Lars<F>
632where
633 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
634{
635 fn fit_pipeline(
636 &self,
637 x: &Array2<F>,
638 y: &Array1<F>,
639 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
640 let fitted = self.fit(x, y)?;
641 Ok(Box::new(fitted))
642 }
643}
644
645impl<F> FittedPipelineEstimator<F> for FittedLars<F>
646where
647 F: Float + ScalarOperand + Send + Sync + 'static,
648{
649 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
650 self.predict(x)
651 }
652}
653
654impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedLassoLars<F> {
659 type Output = Array1<F>;
660 type Error = FerroError;
661
662 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
671 if x.ncols() != self.coefficients.len() {
672 return Err(FerroError::ShapeMismatch {
673 expected: vec![self.coefficients.len()],
674 actual: vec![x.ncols()],
675 context: "number of features must match fitted model".into(),
676 });
677 }
678 Ok(x.dot(&self.coefficients) + self.intercept)
679 }
680}
681
682impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedLassoLars<F> {
683 fn coefficients(&self) -> &Array1<F> {
684 &self.coefficients
685 }
686
687 fn intercept(&self) -> F {
688 self.intercept
689 }
690}
691
692impl<F> PipelineEstimator<F> for LassoLars<F>
693where
694 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
695{
696 fn fit_pipeline(
697 &self,
698 x: &Array2<F>,
699 y: &Array1<F>,
700 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
701 let fitted = self.fit(x, y)?;
702 Ok(Box::new(fitted))
703 }
704}
705
706impl<F> FittedPipelineEstimator<F> for FittedLassoLars<F>
707where
708 F: Float + ScalarOperand + Send + Sync + 'static,
709{
710 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
711 self.predict(x)
712 }
713}
714
// Unit tests for `Lars` and `LassoLars`: builder defaults, small exact fits,
// sparsity, prediction, input validation and pipeline integration.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---- Lars ----

    // `new()` leaves the active-set size uncapped and fits an intercept.
    #[test]
    fn test_lars_defaults() {
        let m = Lars::<f64>::new();
        assert!(m.n_nonzero_coefs.is_none());
        assert!(m.fit_intercept);
    }

    // Builder methods overwrite the corresponding fields.
    #[test]
    fn test_lars_builder() {
        let m = Lars::<f64>::new()
            .with_n_nonzero_coefs(3)
            .with_fit_intercept(false);
        assert_eq!(m.n_nonzero_coefs, Some(3));
        assert!(!m.fit_intercept);
    }

    // y = 2x + 1 exactly, so the fit must recover slope 2 and intercept 1.
    #[test]
    fn test_lars_simple_linear() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = Lars::<f64>::new().fit(&x, &y).unwrap();
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-6);
    }

    // Columns 2 and 3 are 0.1× and 0.01× column 1; with the active set capped
    // at one feature, exactly one coefficient may be non-zero.
    #[test]
    fn test_lars_sparsity() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.1, 0.01, 2.0, 0.2, 0.02, 3.0, 0.3, 0.03, 4.0, 0.4, 0.04,
                5.0, 0.5, 0.05, 6.0, 0.6, 0.06, 7.0, 0.7, 0.07, 8.0, 0.8, 0.08,
                9.0, 0.9, 0.09, 10.0, 1.0, 0.10,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = Lars::<f64>::new().with_n_nonzero_coefs(1).fit(&x, &y).unwrap();
        let nonzero = fitted
            .coefficients()
            .iter()
            .filter(|&&c| c.abs() > 1e-10)
            .count();
        assert_eq!(nonzero, 1);
    }

    // Prediction returns one value per input row.
    #[test]
    fn test_lars_predict() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Lars::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // 3 rows in X but only 2 targets must be rejected.
    #[test]
    fn test_lars_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        assert!(Lars::<f64>::new().fit(&x, &y).is_err());
    }

    // Trained on 2 columns (the second all zeros); predicting with 1 column
    // must fail.
    #[test]
    fn test_lars_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = Lars::<f64>::new().fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // Asking for 5 non-zero coefficients with only 2 features is invalid.
    #[test]
    fn test_lars_n_nonzero_exceeds_features() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(Lars::<f64>::new().with_n_nonzero_coefs(5).fit(&x, &y).is_err());
    }

    // Without intercept fitting the intercept is exactly zero.
    #[test]
    fn test_lars_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Lars::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // One coefficient per feature, even though both columns coincide once
    // centred.
    #[test]
    fn test_lars_has_coefficients() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = Lars::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    // The model works through the dynamic pipeline interface.
    #[test]
    fn test_lars_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];
        let model = Lars::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // ---- LassoLars ----

    // `new()` uses alpha = 1, max_iter = 500 and fits an intercept.
    #[test]
    fn test_lasso_lars_defaults() {
        let m = LassoLars::<f64>::new();
        assert_relative_eq!(m.alpha, 1.0);
        assert_eq!(m.max_iter, 500);
        assert!(m.fit_intercept);
    }

    // Builder methods overwrite the corresponding fields.
    #[test]
    fn test_lasso_lars_builder() {
        let m = LassoLars::<f64>::new()
            .with_alpha(0.5)
            .with_max_iter(100)
            .with_fit_intercept(false);
        assert_relative_eq!(m.alpha, 0.5);
        assert_eq!(m.max_iter, 100);
        assert!(!m.fit_intercept);
    }

    // With alpha = 0 on y = 2x + 1 the slope should be close to 2.
    #[test]
    fn test_lasso_lars_simple() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = LassoLars::<f64>::new()
            .with_alpha(0.0)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.1);
    }

    // Columns 2 and 3 are all zeros, so their coefficients must stay zero.
    #[test]
    fn test_lasso_lars_sparsity() {
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                5.0, 0.0, 0.0, 6.0, 0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0,
                9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = LassoLars::<f64>::new()
            .with_alpha(5.0)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    // A negative alpha is rejected.
    #[test]
    fn test_lasso_lars_negative_alpha() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        assert!(LassoLars::<f64>::new().with_alpha(-1.0).fit(&x, &y).is_err());
    }

    // 3 rows in X but only 2 targets must be rejected.
    #[test]
    fn test_lasso_lars_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        assert!(LassoLars::<f64>::new().fit(&x, &y).is_err());
    }

    // Prediction returns one value per input row.
    #[test]
    fn test_lasso_lars_predict() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];
        let fitted = LassoLars::<f64>::new().with_alpha(0.01).fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // One coefficient per feature.
    #[test]
    fn test_lasso_lars_has_coefficients() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = LassoLars::<f64>::new().with_alpha(0.01).fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    // The model works through the dynamic pipeline interface.
    #[test]
    fn test_lasso_lars_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];
        let model = LassoLars::<f64>::new().with_alpha(0.01);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}