use std::marker::PhantomData;

use scirs2_core::ndarray::{s, Array};
use scirs2_linalg::compat::ArrayLinalgExt;
use sklears_core::{
    error::{validate, Result, SklearsError},
    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
    types::{Array1, Array2, Float},
};

use crate::{Penalty, Solver};

#[cfg(feature = "coordinate-descent")]
use crate::coordinate_descent::CoordinateDescentSolver;

#[cfg(feature = "coordinate-descent")]
use crate::coordinate_descent::ValidationInfo;

#[cfg(feature = "early-stopping")]
use crate::early_stopping::EarlyStoppingConfig;

/// Configuration for [`LinearRegression`].
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Whether to fit an intercept term.
    pub fit_intercept: bool,
    /// Regularization penalty applied to the coefficients.
    pub penalty: Penalty,
    /// Strategy used to solve the fitting problem.
    pub solver: Solver,
    /// Maximum number of iterations for iterative solvers.
    pub max_iter: usize,
    /// Convergence tolerance for iterative solvers.
    pub tol: f64,
    /// Whether to reuse a previous solution as initialization.
    pub warm_start: bool,
    /// Whether to attempt GPU acceleration for large problems.
    #[cfg(feature = "gpu")]
    pub use_gpu: bool,
    /// Minimum problem size (`n_samples * n_features`) before the GPU path is tried.
    #[cfg(feature = "gpu")]
    pub gpu_min_size: usize,
}

impl Default for LinearRegressionConfig {
    fn default() -> Self {
        Self {
            fit_intercept: true,
            penalty: Penalty::None,
            solver: Solver::Auto,
            max_iter: 1000,
            tol: 1e-4,
            warm_start: false,
            #[cfg(feature = "gpu")]
            use_gpu: true,
            #[cfg(feature = "gpu")]
            gpu_min_size: 1000,
        }
    }
}

/// Linear regression with optional L1/L2/ElasticNet regularization.
///
/// The `State` type parameter implements the type-state pattern: fitting an
/// `Untrained` model consumes it and returns a `Trained` one, so calling
/// `predict` on an unfitted model is a compile-time error.
#[derive(Debug, Clone)]
pub struct LinearRegression<State = Untrained> {
    config: LinearRegressionConfig,
    state: PhantomData<State>,
    /// Fitted coefficients (`None` until trained).
    coef_: Option<Array1<Float>>,
    intercept_: Option<Float>,
    n_features_: Option<usize>,
}

impl LinearRegression<Untrained> {
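    /// Create an untrained `LinearRegression` with the default configuration
    /// (intercept fitting on, no penalty, automatic solver selection).
    ///
    /// A minimal usage sketch; the fence is `ignore`d because the crate path
    /// `sklears_linear` is an assumption here, not confirmed by this file:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    /// use sklears_core::traits::{Fit, Predict};
    /// use sklears_linear::LinearRegression;
    ///
    /// let x = array![[1.0], [2.0], [3.0]];
    /// let y = array![2.0, 4.0, 6.0];
    /// let model = LinearRegression::new().fit_intercept(false).fit(&x, &y)?;
    /// let preds = model.predict(&array![[4.0]])?; // close to 8.0
    /// # Ok::<(), sklears_core::error::SklearsError>(())
    /// ```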
    pub fn new() -> Self {
        Self {
            config: LinearRegressionConfig::default(),
            state: PhantomData,
            coef_: None,
            intercept_: None,
            n_features_: None,
        }
    }

    /// Set whether an intercept term is fitted (default: `true`).
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.config.fit_intercept = fit_intercept;
        self
    }

    /// Apply L2 (ridge) regularization with strength `alpha`.
    pub fn regularization(mut self, alpha: f64) -> Self {
        self.config.penalty = Penalty::L2(alpha);
        self
    }

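    /// Construct a Lasso model: an L1 penalty of strength `alpha`, solved by
    /// coordinate descent (requires the `coordinate-descent` feature).
    ///
    /// A sketch of the effect of `alpha` (`ignore`d because the crate path
    /// `sklears_linear` is assumed):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    /// use sklears_core::traits::Fit;
    /// use sklears_linear::LinearRegression;
    ///
    /// let x = array![[1.0], [2.0], [3.0], [4.0]];
    /// let y = array![2.0, 4.0, 6.0, 8.0];
    /// // Small alpha: coefficient stays near the OLS value of 2.0.
    /// let weak = LinearRegression::lasso(0.01).fit_intercept(false).fit(&x, &y)?;
    /// // Large alpha: the coefficient is shrunk toward zero.
    /// let strong = LinearRegression::lasso(0.5).fit_intercept(false).fit(&x, &y)?;
    /// assert!(strong.coef()[0] < weak.coef()[0]);
    /// # Ok::<(), sklears_core::error::SklearsError>(())
    /// ```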
    pub fn lasso(alpha: f64) -> Self {
        Self::new()
            .penalty(Penalty::L1(alpha))
            .solver(Solver::CoordinateDescent)
    }

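    /// Construct an ElasticNet model mixing L1 and L2 penalties with overall
    /// strength `alpha`; `l1_ratio = 1.0` recovers Lasso, `0.0` recovers ridge.
    ///
    /// Sketch (crate path assumed, hence `ignore`d):
    ///
    /// ```ignore
    /// use sklears_linear::LinearRegression;
    ///
    /// // 70% of the alpha = 0.1 penalty budget goes to the L1 term.
    /// let model = LinearRegression::elastic_net(0.1, 0.7);
    /// ```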
    pub fn elastic_net(alpha: f64, l1_ratio: f64) -> Self {
        Self::new()
            .penalty(Penalty::ElasticNet { l1_ratio, alpha })
            .solver(Solver::CoordinateDescent)
    }

    /// Set the regularization penalty explicitly.
    pub fn penalty(mut self, penalty: Penalty) -> Self {
        self.config.penalty = penalty;
        self
    }

    /// Select the solver.
    pub fn solver(mut self, solver: Solver) -> Self {
        self.config.solver = solver;
        self
    }

    /// Set the maximum number of iterations for iterative solvers.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Enable or disable warm starting.
    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.config.warm_start = warm_start;
        self
    }

    /// Enable or disable the GPU path for ordinary least squares.
    #[cfg(feature = "gpu")]
    pub fn use_gpu(mut self, use_gpu: bool) -> Self {
        self.config.use_gpu = use_gpu;
        self
    }

    /// Set the minimum problem size before the GPU path is tried.
    #[cfg(feature = "gpu")]
    pub fn gpu_min_size(mut self, min_size: usize) -> Self {
        self.config.gpu_min_size = min_size;
        self
    }
}

impl Default for LinearRegression<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for LinearRegression<Untrained> {
    type Config = LinearRegressionConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<Float>> for LinearRegression<Untrained> {
    type Fitted = LinearRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        validate::check_consistent_length(x, y)?;

        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Prepend a column of ones when an intercept is fitted.
        let (x_with_intercept, n_params) = if self.config.fit_intercept {
            let mut x_new = Array::ones((n_samples, n_features + 1));
            x_new.slice_mut(s![.., 1..]).assign(x);
            (x_new, n_features + 1)
        } else {
            (x.clone(), n_features)
        };

        let params = match self.config.penalty {
            Penalty::None => {
                #[cfg(feature = "gpu")]
                if self.config.use_gpu && n_samples * n_features >= self.config.gpu_min_size {
                    match self.solve_ols_gpu(&x_with_intercept, y) {
                        Ok(params) => params,
                        Err(_) => {
                            // GPU path failed; fall back to the CPU solver.
                            self.solve_ols_cpu(&x_with_intercept, y)?
                        }
                    }
                } else {
                    self.solve_ols_cpu(&x_with_intercept, y)?
                }

                #[cfg(not(feature = "gpu"))]
                self.solve_ols_cpu(&x_with_intercept, y)?
            }
            Penalty::L2(alpha) => {
                // Ridge regression: solve (X^T X + alpha * I) beta = X^T y.
                let xtx = x_with_intercept.t().dot(&x_with_intercept);
                let xty = x_with_intercept.t().dot(y);

                // Skip the first diagonal entry so the intercept is not penalized.
                let mut regularized = xtx.clone();
                let start_idx = if self.config.fit_intercept { 1 } else { 0 };
                for i in start_idx..n_params {
                    regularized[[i, i]] += alpha;
                }

                regularized.solve(&xty).map_err(|e| {
                    SklearsError::NumericalError(format!("Failed to solve ridge regression: {}", e))
                })?
            }
            Penalty::L1(alpha) => {
                #[cfg(feature = "coordinate-descent")]
                {
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_lasso(x, y, alpha, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Pack [intercept, coefficients...] into one parameter vector.
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "L1 regularization (Lasso) requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                #[cfg(feature = "coordinate-descent")]
                {
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net(x, y, alpha, l1_ratio, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Pack [intercept, coefficients...] into one parameter vector.
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "ElasticNet regularization requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
        };

        // Unpack the parameter vector into coefficients and intercept.
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}

impl LinearRegression<Untrained> {
    /// Solve ordinary least squares on the CPU via the normal equations.
    fn solve_ols_cpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        // Normal equations: X^T X beta = X^T y.
        let xtx = x.t().dot(x);
        let xty = x.t().dot(y);

        xtx.solve(&xty).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to solve linear system: {}", e))
        })
    }

    /// Solve ordinary least squares on the GPU, returning an error (so the
    /// caller can fall back to the CPU) if no device is available.
    #[cfg(feature = "gpu")]
    fn solve_ols_gpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        use crate::gpu_acceleration::{GpuConfig, GpuLinearOps};

        let gpu_config = GpuConfig {
            device_id: 0,
            use_pinned_memory: true,
            min_problem_size: self.config.gpu_min_size,
            ..Default::default()
        };

        let gpu_ops = GpuLinearOps::new(gpu_config).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to initialize GPU operations: {}", e))
        })?;

        if !gpu_ops.is_gpu_available() {
            return Err(SklearsError::NumericalError(
                "GPU not available, falling back to CPU".to_string(),
            ));
        }

        // Form the normal equations on the device, then solve.
        let xt = gpu_ops.matrix_transpose(x)?;
        let xtx = gpu_ops.matrix_multiply(&xt, x)?;
        let xty = gpu_ops.matrix_vector_multiply(&xt, y)?;

        gpu_ops.solve_linear_system(&xtx, &xty)
    }

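    /// Fit a regularized model starting from user-supplied coefficients.
    ///
    /// This is useful when refitting along a path of `alpha` values, where
    /// the previous solution is a good initialization. Only regularized
    /// penalties are supported; `Penalty::None` is rejected.
    ///
    /// A sketch of a two-step regularization path (`ignore`d; the crate path
    /// and the bindings `x`/`y` are assumed):
    ///
    /// ```ignore
    /// use sklears_core::traits::Fit;
    /// use sklears_linear::LinearRegression;
    ///
    /// let coarse = LinearRegression::lasso(1.0).fit(&x, &y)?;
    /// let fine = LinearRegression::lasso(0.1).fit_with_warm_start(
    ///     &x,
    ///     &y,
    ///     Some(coarse.coef()),
    ///     coarse.intercept(),
    /// )?;
    /// # Ok::<(), sklears_core::error::SklearsError>(())
    /// ```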
    pub fn fit_with_warm_start(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<LinearRegression<Trained>> {
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        let params: Array1<Float> = match self.config.penalty {
            Penalty::L1(_)
            | Penalty::L2(_)
            | Penalty::ElasticNet {
                alpha: _,
                l1_ratio: _,
            } => {
                #[cfg(feature = "coordinate-descent")]
                {
                    // Map every penalty onto the elastic-net solver:
                    // L1 is l1_ratio = 1.0 and L2 is l1_ratio = 0.0.
                    let (alpha_val, l1_ratio) = match self.config.penalty {
                        Penalty::L1(alpha) => (alpha, 1.0),
                        Penalty::L2(alpha) => (alpha, 0.0),
                        Penalty::ElasticNet { alpha, l1_ratio } => (alpha, l1_ratio),
                        _ => unreachable!(),
                    };

                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net_with_warm_start(
                            x,
                            y,
                            alpha_val,
                            l1_ratio,
                            self.config.fit_intercept,
                            initial_coef,
                            initial_intercept,
                        )
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Pack [intercept, coefficients...] into one parameter vector.
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason: "Warm start requires the 'coordinate-descent' feature".to_string(),
                    });
                }
            }
            Penalty::None => {
                return Err(SklearsError::InvalidParameter {
                    name: "penalty".to_string(),
                    reason:
                        "Warm start only supported for regularized methods (L1, L2, ElasticNet)"
                            .to_string(),
                });
            }
        };

        // Unpack the parameter vector into coefficients and intercept.
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}

impl LinearRegression<Trained> {
    /// Fitted coefficients, one per feature.
    pub fn coef(&self) -> &Array1<Float> {
        self.coef_.as_ref().expect("Model is trained")
    }

    /// Fitted intercept, if one was requested at configuration time.
    pub fn intercept(&self) -> Option<Float> {
        self.intercept_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let n_features = self.n_features_.expect("Model is trained");
        validate::check_n_features(x, n_features)?;

        let coef = self.coef_.as_ref().expect("Model is trained");
        let mut predictions = x.dot(coef);

        if let Some(intercept) = self.intercept_ {
            predictions += intercept;
        }

        Ok(predictions)
    }
}

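/// Scores predictions with the coefficient of determination:
/// R^2 = 1 - SS_res / SS_tot, where SS_res = sum((y_hat_i - y_i)^2) and
/// SS_tot = sum((y_i - y_mean)^2).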
impl Score<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
    type Float = Float;

    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
        let predictions = self.predict(x)?;

        // Residual and total sums of squares.
        let ss_res = (&predictions - y).mapv(|r| r * r).sum();
        let y_mean = y.mean().unwrap_or(0.0);
        let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();

        // A constant target is predicted exactly by its mean; report a
        // perfect score rather than dividing by zero.
        if ss_tot == 0.0 {
            return Ok(1.0);
        }

        Ok(1.0 - (ss_res / ss_tot))
    }
}

impl LinearRegression<Untrained> {
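    /// Fit with early stopping on an internal validation split.
    ///
    /// For L1 and ElasticNet penalties the coordinate-descent solver monitors
    /// a validation score and may stop before `max_iter`. Ridge and OLS have
    /// direct solutions, so they are fitted in one step and the returned
    /// [`ValidationInfo`] contains placeholder scores.
    ///
    /// Sketch (`ignore`d; `x`, `y`, and `early_stopping_config` are assumed
    /// to be bound, the config built field-by-field as in the tests below):
    ///
    /// ```ignore
    /// let (model, info) = LinearRegression::lasso(0.1)
    ///     .fit_with_early_stopping(&x, &y, early_stopping_config)?;
    /// assert!(info.best_score.is_some());
    /// ```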
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping(x, y, alpha, self.config.fit_intercept)?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping(
                        x,
                        y,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // Ridge has a closed-form solution, so there is no iteration
                // to stop early; fit directly and report placeholder scores.
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0], // placeholder: no iterative history
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // OLS is also solved directly; same placeholder reporting.
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0], // placeholder: no iterative history
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
        }
    }

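    /// Fit with early stopping against an explicit validation set instead of
    /// an internal split; `x_val`/`y_val` must have the same number of
    /// features as the training data.
    ///
    /// Sketch (`ignore`d; the data bindings and `config` are assumed, the
    /// config built as in the tests below):
    ///
    /// ```ignore
    /// let (model, info) = LinearRegression::lasso(0.01)
    ///     .fit_with_early_stopping_split(&x_train, &y_train, &x_val, &y_val, config)?;
    /// assert!(!info.validation_scores.is_empty());
    /// ```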
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping_split(
        self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        validate::check_consistent_length(x_train, y_train)?;
        validate::check_consistent_length(x_val, y_val)?;

        let n_features = x_train.ncols();
        if x_val.ncols() != n_features {
            return Err(SklearsError::FeatureMismatch {
                expected: n_features,
                actual: x_val.ncols(),
            });
        }

        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // Ridge is solved directly on the training data...
                let fitted_model = LinearRegression::new()
                    .penalty(self.config.penalty)
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                // ...and scored once on the provided validation set.
                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // OLS is solved directly on the training data...
                let fitted_model = LinearRegression::new()
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                // ...and scored once on the provided validation set.
                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_linear_regression_simple() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::new()
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);

        let predictions = model.predict(&array![[5.0]]).unwrap();
        assert_abs_diff_eq!(predictions[0], 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_linear_regression_with_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0]; // y = 2x + 1

        let model = LinearRegression::new()
            .fit_intercept(true)
            .fit(&x, &y)
            .unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(model.intercept().unwrap(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ridge_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::new()
            .fit_intercept(false)
            .regularization(0.1)
            .fit(&x, &y)
            .unwrap();

        // The penalty shrinks the coefficient slightly below the OLS value of 2.0.
        assert!(model.coef()[0] < 2.0);
        assert!(model.coef()[0] > 1.9);
    }

    #[test]
    fn test_lasso_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        // A small penalty leaves the solution essentially unregularized.
        let model = LinearRegression::lasso(0.01)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 0.1);

        // A larger penalty visibly shrinks the coefficient.
        let model = LinearRegression::lasso(0.5)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert!(model.coef()[0] < 2.0);
        assert!(model.coef()[0] > 1.0);
    }

    #[test]
    fn test_elastic_net_regression() {
        let x = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let y = array![3.0, 6.0, 9.0, 12.0]; // y = 3 * x1, with x2 = x1 / 2 (collinear)

        let model = LinearRegression::elastic_net(0.1, 0.5)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        println!(
            "ElasticNet coef[0] = {}, coef[1] = {}",
            model.coef()[0],
            model.coef()[1]
        );
        // With collinear features the penalty spreads weight across both,
        // keeping each coefficient positive but below the full value of 3.0.
        assert!(model.coef()[0] > 0.0);
        assert!(model.coef()[0] < 3.0);
        assert!(model.coef()[1] > 0.0);
        assert!(model.coef()[1] < 3.0);
    }

    #[test]
    fn test_lasso_sparsity() {
        // Only the first feature drives the target; the rest are deterministic noise.
        let n_samples = 20;
        let mut x = Array2::zeros((n_samples, 5));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            x[[i, 0]] = i as f64;
            x[[i, 1]] = (i as f64) * 0.1; // perfectly correlated with x0
            x[[i, 2]] = ((i * 7) % 10) as f64 / 10.0; // pseudo-random noise
            x[[i, 3]] = ((i * 13) % 10) as f64 / 10.0; // pseudo-random noise
            x[[i, 4]] = ((i * 17) % 10) as f64 / 10.0; // pseudo-random noise
            y[i] = 2.0 * x[[i, 0]] + 0.05 * (i % 3) as f64;
        }

        // A strong L1 penalty should zero out the irrelevant features.
        let model = LinearRegression::lasso(1.0)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        let coef = model.coef();

        // The informative feature keeps a substantial weight...
        assert!(coef[0] > 0.5);

        // ...while the noise features are driven to (near) zero.
        for i in 2..5 {
            assert_abs_diff_eq!(coef[i], 0.0, epsilon = 0.01);
        }
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_lasso() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        // Synthetic data: three informative features plus mild structured noise.
        let n_samples = 100;
        let n_features = 8;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i * j + 1) as f64 / 20.0;
            }
            y[i] = 2.0 * x[[i, 0]] + 1.5 * x[[i, 1]] + 0.8 * x[[i, 2]] + 0.1 * (i as f64 % 5.0);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(10),
            validation_split: 0.25,
            shuffle: true,
            random_state: Some(42),
            higher_is_better: true,
            min_iterations: 5,
            restore_best_weights: true,
        };

        let model = LinearRegression::lasso(0.1);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        // The fitted model exposes coefficients for every feature and an intercept.
        assert_eq!(fitted_model.coef().len(), n_features);
        assert!(fitted_model.intercept().is_some());

        // A validation history was recorded.
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
        assert!(validation_info.best_iteration >= 1);

        // Predictions work on the training data.
        let predictions = fitted_model.predict(&x);
        assert!(predictions.is_ok());
        assert_eq!(predictions.unwrap().len(), n_samples);
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_elastic_net() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 3.0, 1.0],
            [3.0, 4.0, 1.5],
            [4.0, 5.0, 2.0],
            [5.0, 6.0, 2.5],
            [6.0, 7.0, 3.0],
            [7.0, 8.0, 3.5],
            [8.0, 9.0, 4.0]
        ];
        let y = array![4.5, 7.0, 9.5, 12.0, 14.5, 17.0, 19.5, 22.0]; // y = 4.5 + 2.5 * row index

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TolerancePatience {
                tolerance: 0.005,
                patience: 3,
            },
            validation_split: 0.25,
            shuffle: false,
            random_state: Some(123),
            higher_is_better: true,
            min_iterations: 2,
            restore_best_weights: true,
        };

        let model = LinearRegression::elastic_net(0.1, 0.7);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 3);
        assert!(fitted_model.intercept().is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_with_split() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let y_train = array![5.0, 8.0, 11.0, 14.0, 17.0]; // consistent with y = 2*x1 + x2 + 1

        let x_val = array![[6.0, 7.0], [7.0, 8.0]];
        let y_val = array![20.0, 23.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TargetScore(0.9),
            validation_split: 0.2, // unused: an explicit validation set is provided
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: false,
        };

        let model = LinearRegression::lasso(0.01);
        let result = model.fit_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            early_stopping_config,
        );

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 2);
        assert!(fitted_model.intercept().is_some());
        assert!(!validation_info.validation_scores.is_empty());

        // Coefficients should approximate the generating model.
        let coef = fitted_model.coef();
        assert!((coef[0] - 2.0).abs() < 0.5);
        assert!((coef[1] - 1.0).abs() < 0.5);
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_ols() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0, 13.0]; // y = 2x + 1

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(5),
            validation_split: 0.33,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        // OLS has a direct solution, so early stopping degenerates to one fit.
        let model = LinearRegression::new().fit_intercept(true);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 1);
        assert!(fitted_model.intercept().is_some());

        // Direct solvers never stop early and report convergence immediately.
        assert!(!validation_info.stopped_early);
        assert!(validation_info.converged);
        assert_eq!(validation_info.best_iteration, 1);

        // Exact recovery of y = 2x + 1.
        assert_abs_diff_eq!(fitted_model.coef()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted_model.intercept().unwrap(), 1.0, epsilon = 1e-10);
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_ridge() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![
            [1.0, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
            [6.0, 3.0]
        ];
        let y = array![2.5, 4.0, 5.5, 7.0, 8.5, 10.0]; // y = x1 + x2 + 1, with x2 = x1 / 2

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(3),
            validation_split: 0.33,
            shuffle: true,
            random_state: Some(456),
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: false,
        };

        let model = LinearRegression::new()
            .regularization(0.1)
            .fit_intercept(true);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 2);
        assert!(fitted_model.intercept().is_some());

        // Ridge is solved directly, so no early stopping occurs.
        assert!(!validation_info.stopped_early);
        assert!(validation_info.converged);
    }
}
1128}