sklears_linear/linear_regression.rs

//! Linear Regression implementation
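//!
//! A minimal usage sketch (illustrative, not a compiled doctest: the
//! `sklears_linear` import path and the `array!` macro re-export are assumed):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_core::traits::{Fit, Predict};
//! use sklears_linear::LinearRegression;
//!
//! let x = array![[1.0], [2.0], [3.0], [4.0]];
//! let y = array![3.0, 5.0, 7.0, 9.0]; // y = 2x + 1
//!
//! // Ordinary least squares with an intercept.
//! let model = LinearRegression::new().fit_intercept(true).fit(&x, &y).unwrap();
//! assert!((model.coef()[0] - 2.0).abs() < 1e-8);
//! assert!((model.intercept().unwrap() - 1.0).abs() < 1e-8);
//!
//! // Predict on new data.
//! let predictions = model.predict(&array![[5.0]]).unwrap();
//! assert!((predictions[0] - 11.0).abs() < 1e-8);
//! ```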

use std::marker::PhantomData;

use scirs2_core::ndarray::{s, Array};
use scirs2_linalg::compat::ArrayLinalgExt;
// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
use sklears_core::{
    error::{validate, Result, SklearsError},
    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
    types::{Array1, Array2, Float},
};

use crate::{Penalty, Solver};

#[cfg(feature = "coordinate-descent")]
use crate::coordinate_descent::CoordinateDescentSolver;

#[cfg(feature = "coordinate-descent")]
use crate::coordinate_descent::ValidationInfo;

#[cfg(feature = "early-stopping")]
use crate::early_stopping::EarlyStoppingConfig;

/// Configuration for Linear Regression
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Whether to fit the intercept
    pub fit_intercept: bool,
    /// Regularization penalty
    pub penalty: Penalty,
    /// Solver to use
    pub solver: Solver,
    /// Maximum iterations for iterative solvers
    pub max_iter: usize,
    /// Tolerance for convergence
    pub tol: f64,
    /// Whether to use warm start (reuse previous solution as initialization)
    pub warm_start: bool,
    /// Enable GPU acceleration if available
    #[cfg(feature = "gpu")]
    pub use_gpu: bool,
    /// Minimum problem size to use GPU acceleration
    #[cfg(feature = "gpu")]
    pub gpu_min_size: usize,
}

impl Default for LinearRegressionConfig {
    fn default() -> Self {
        Self {
            fit_intercept: true,
            penalty: Penalty::None,
            solver: Solver::Auto,
            max_iter: 1000,
            tol: 1e-4,
            warm_start: false,
            #[cfg(feature = "gpu")]
            use_gpu: true,
            #[cfg(feature = "gpu")]
            gpu_min_size: 1000,
        }
    }
}

/// Linear Regression model
#[derive(Debug, Clone)]
pub struct LinearRegression<State = Untrained> {
    config: LinearRegressionConfig,
    state: PhantomData<State>,
    // Trained state fields
    coef_: Option<Array1<Float>>,
    intercept_: Option<Float>,
    n_features_: Option<usize>,
}

impl LinearRegression<Untrained> {
    /// Create a new Linear Regression model
    pub fn new() -> Self {
        Self {
            config: LinearRegressionConfig::default(),
            state: PhantomData,
            coef_: None,
            intercept_: None,
            n_features_: None,
        }
    }

    /// Set whether to fit intercept
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.config.fit_intercept = fit_intercept;
        self
    }

    /// Set regularization (Ridge/L2)
    pub fn regularization(mut self, alpha: f64) -> Self {
        self.config.penalty = Penalty::L2(alpha);
        self
    }

    /// Create a Lasso regression model (L1 penalty)
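    ///
    /// A minimal sketch of the shorthand (illustrative; requires the
    /// `coordinate-descent` feature, and `x`/`y` are assumed to be ndarray inputs):
    ///
    /// ```ignore
    /// // Equivalent to LinearRegression::new()
    /// //     .penalty(Penalty::L1(0.1))
    /// //     .solver(Solver::CoordinateDescent)
    /// let model = LinearRegression::lasso(0.1).fit(&x, &y).unwrap();
    /// ```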
    pub fn lasso(alpha: f64) -> Self {
        Self::new()
            .penalty(Penalty::L1(alpha))
            .solver(Solver::CoordinateDescent)
    }

    /// Create an ElasticNet regression model (L1 + L2 penalty)
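    ///
    /// Illustrative sketch (requires the `coordinate-descent` feature; `x`/`y` assumed):
    ///
    /// ```ignore
    /// // alpha = 0.1 overall strength, l1_ratio = 0.5 splits the penalty evenly
    /// // between the L1 and L2 terms.
    /// let model = LinearRegression::elastic_net(0.1, 0.5).fit(&x, &y).unwrap();
    /// ```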
    pub fn elastic_net(alpha: f64, l1_ratio: f64) -> Self {
        Self::new()
            .penalty(Penalty::ElasticNet { l1_ratio, alpha })
            .solver(Solver::CoordinateDescent)
    }

    /// Set penalty
    pub fn penalty(mut self, penalty: Penalty) -> Self {
        self.config.penalty = penalty;
        self
    }

    /// Set solver
    pub fn solver(mut self, solver: Solver) -> Self {
        self.config.solver = solver;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Set whether to use warm start
    pub fn warm_start(mut self, warm_start: bool) -> Self {
        self.config.warm_start = warm_start;
        self
    }

    /// Enable or disable GPU acceleration
    #[cfg(feature = "gpu")]
    pub fn use_gpu(mut self, use_gpu: bool) -> Self {
        self.config.use_gpu = use_gpu;
        self
    }

    /// Set minimum problem size for GPU acceleration
    #[cfg(feature = "gpu")]
    pub fn gpu_min_size(mut self, min_size: usize) -> Self {
        self.config.gpu_min_size = min_size;
        self
    }
}

impl Default for LinearRegression<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for LinearRegression<Untrained> {
    type Config = LinearRegressionConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<Float>> for LinearRegression<Untrained> {
    type Fitted = LinearRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Add intercept column if needed
        let (x_with_intercept, n_params) = if self.config.fit_intercept {
            let mut x_new = Array::ones((n_samples, n_features + 1));
            x_new.slice_mut(s![.., 1..]).assign(x);
            (x_new, n_features + 1)
        } else {
            (x.clone(), n_features)
        };

        // Solve based on penalty type
        let params = match self.config.penalty {
            Penalty::None => {
                // Check if we should use GPU acceleration
                #[cfg(feature = "gpu")]
                if self.config.use_gpu && n_samples * n_features >= self.config.gpu_min_size {
                    // Try GPU-accelerated OLS
                    match self.solve_ols_gpu(&x_with_intercept, y) {
                        Ok(params) => params,
                        Err(_) => {
                            // Fallback to CPU if GPU fails
                            self.solve_ols_cpu(&x_with_intercept, y)?
                        }
                    }
                } else {
                    self.solve_ols_cpu(&x_with_intercept, y)?
                }

                #[cfg(not(feature = "gpu"))]
                self.solve_ols_cpu(&x_with_intercept, y)?
            }
            Penalty::L2(alpha) => {
                // Ridge regression
                // (X^T X + αI) β = X^T y
                let xtx = x_with_intercept.t().dot(&x_with_intercept);
                let xty = x_with_intercept.t().dot(y);

                // Add regularization to diagonal (except intercept if present)
                let mut regularized = xtx.clone();
                let start_idx = if self.config.fit_intercept { 1 } else { 0 };
                for i in start_idx..n_params {
                    regularized[[i, i]] += alpha;
                }

                regularized.solve(&xty).map_err(|e| {
                    SklearsError::NumericalError(format!("Failed to solve ridge regression: {}", e))
                })?
            }
            Penalty::L1(alpha) => {
                // Lasso regression using coordinate descent
                #[cfg(feature = "coordinate-descent")]
                {
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_lasso(x, y, alpha, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Need to add intercept to beginning of params for consistency
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "L1 regularization (Lasso) requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                // ElasticNet regression using coordinate descent
                #[cfg(feature = "coordinate-descent")]
                {
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net(x, y, alpha, l1_ratio, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Need to add intercept to beginning of params for consistency
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "ElasticNet regularization requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
        };

        // Extract coefficients and intercept
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}

impl LinearRegression<Untrained> {
    /// CPU-based OLS solver
    fn solve_ols_cpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        // Ordinary Least Squares using scirs2
        // X^T X β = X^T y
        let xtx = x.t().dot(x);
        let xty = x.t().dot(y);

        // Use scirs2's linear solver
        xtx.solve(&xty).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to solve linear system: {}", e))
        })
    }

    /// GPU-based OLS solver
    #[cfg(feature = "gpu")]
    fn solve_ols_gpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        use crate::gpu_acceleration::{GpuConfig, GpuLinearOps};

        // Initialize GPU operations
        let gpu_config = GpuConfig {
            device_id: 0,
            use_pinned_memory: true,
            min_problem_size: self.config.gpu_min_size,
            ..Default::default()
        };

        let gpu_ops = GpuLinearOps::new(gpu_config).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to initialize GPU operations: {}", e))
        })?;

        // Check if GPU is available
        if !gpu_ops.is_gpu_available() {
            return Err(SklearsError::NumericalError(
                "GPU not available, falling back to CPU".to_string(),
            ));
        }

        // Compute X^T X using GPU
        let xt = gpu_ops.matrix_transpose(x)?;
        let xtx = gpu_ops.matrix_multiply(&xt, x)?;

        // Compute X^T y using GPU
        let xty = gpu_ops.matrix_vector_multiply(&xt, y)?;

        // Solve linear system using GPU
        gpu_ops.solve_linear_system(&xtx, &xty)
    }

    /// Fit the linear regression model with warm start
    ///
    /// Uses the provided coefficients and intercept as initialization for iterative solvers
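    ///
    /// A minimal sketch of reusing a previous solution, e.g. along a regularization
    /// path (illustrative; requires the `coordinate-descent` feature and assumes
    /// ndarray inputs `x`/`y`):
    ///
    /// ```ignore
    /// // Coarse fit at a larger alpha.
    /// let coarse = LinearRegression::lasso(1.0).fit(&x, &y).unwrap();
    /// let (coef, intercept) = (coarse.coef().clone(), coarse.intercept());
    ///
    /// // Refit at a smaller alpha, starting from the previous coefficients.
    /// let refined = LinearRegression::lasso(0.1)
    ///     .fit_with_warm_start(&x, &y, Some(&coef), intercept)
    ///     .unwrap();
    /// ```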
    pub fn fit_with_warm_start(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<LinearRegression<Trained>> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        // Warm start is only supported for penalties handled by coordinate descent (L1, L2, ElasticNet)
        let params: Array1<Float> = match self.config.penalty {
            Penalty::L1(_)
            | Penalty::L2(_)
            | Penalty::ElasticNet {
                alpha: _,
                l1_ratio: _,
            } => {
                #[cfg(feature = "coordinate-descent")]
                {
                    let (alpha_val, l1_ratio) = match self.config.penalty {
                        Penalty::L1(alpha) => (alpha, 1.0),
                        Penalty::L2(alpha) => (alpha, 0.0),
                        Penalty::ElasticNet { alpha, l1_ratio } => (alpha, l1_ratio),
                        _ => unreachable!(),
                    };

                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net_with_warm_start(
                            x,
                            y,
                            alpha_val,
                            l1_ratio,
                            self.config.fit_intercept,
                            initial_coef,
                            initial_intercept,
                        )
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Need to add intercept to beginning of params for consistency
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason: "Warm start requires the 'coordinate-descent' feature".to_string(),
                    });
                }
            }
            Penalty::None => {
                return Err(SklearsError::InvalidParameter {
                    name: "penalty".to_string(),
                    reason:
                        "Warm start only supported for regularized methods (L1, L2, ElasticNet)"
                            .to_string(),
                });
            }
        };

        // Extract coefficients and intercept
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}

impl LinearRegression<Trained> {
    /// Get the coefficients
    pub fn coef(&self) -> &Array1<Float> {
        self.coef_.as_ref().expect("Model is trained")
    }

    /// Get the intercept
    pub fn intercept(&self) -> Option<Float> {
        self.intercept_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let n_features = self.n_features_.expect("Model is trained");
        validate::check_n_features(x, n_features)?;

        let coef = self.coef_.as_ref().expect("Model is trained");
        let mut predictions = x.dot(coef);

        if let Some(intercept) = self.intercept_ {
            predictions += intercept;
        }

        Ok(predictions)
    }
}

impl Score<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
    type Float = Float;

    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
        let predictions = self.predict(x)?;

        // Calculate R² score using scirs2 metrics
        let ss_res = (&predictions - y).mapv(|x| x * x).sum();
        let y_mean = y.mean().unwrap_or(0.0);
        let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();

        if ss_tot == 0.0 {
            return Ok(1.0);
        }

        Ok(1.0 - (ss_res / ss_tot))
    }
}

impl LinearRegression<Untrained> {
    /// Fit the linear regression model with early stopping based on validation metrics
    ///
    /// This method is particularly useful for regularized methods (Lasso, ElasticNet)
    /// where early stopping can prevent overfitting.
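    ///
    /// Illustrative sketch (field values mirror the tests below; import paths are
    /// assumed and both the `coordinate-descent` and `early-stopping` features are
    /// required):
    ///
    /// ```ignore
    /// use sklears_linear::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
    ///
    /// let config = EarlyStoppingConfig {
    ///     criterion: StoppingCriterion::Patience(10),
    ///     validation_split: 0.25,
    ///     shuffle: true,
    ///     random_state: Some(42),
    ///     higher_is_better: true,
    ///     min_iterations: 5,
    ///     restore_best_weights: true,
    /// };
    /// let (model, info) = LinearRegression::lasso(0.1)
    ///     .fit_with_early_stopping(&x, &y, config)
    ///     .unwrap();
    /// assert!(info.best_score.is_some());
    /// ```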
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        // Early stopping is most beneficial for regularized methods
        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping(x, y, alpha, self.config.fit_intercept)?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping(
                        x,
                        y,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // Ridge regression could use an iterative solver with early stopping;
                // for now, fall back to the regular fit and provide minimal validation info
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0], // Dummy score
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // For OLS, early stopping doesn't make much sense since it's a direct solution
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0], // Dummy score
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
        }
    }

    /// Fit the linear regression model with early stopping using pre-split validation data
    ///
    /// This gives you more control over the train/validation split compared to
    /// `fit_with_early_stopping` which automatically splits the data.
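    ///
    /// Illustrative sketch with an explicit train/validation split (import paths
    /// assumed; requires the `coordinate-descent` and `early-stopping` features):
    ///
    /// ```ignore
    /// let (model, info) = LinearRegression::lasso(0.01)
    ///     .fit_with_early_stopping_split(&x_train, &y_train, &x_val, &y_val, config)
    ///     .unwrap();
    /// assert!(!info.validation_scores.is_empty());
    /// ```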
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping_split(
        self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        // Validate inputs
        validate::check_consistent_length(x_train, y_train)?;
        validate::check_consistent_length(x_val, y_val)?;

        let n_features = x_train.ncols();
        if x_val.ncols() != n_features {
            return Err(SklearsError::FeatureMismatch {
                expected: n_features,
                actual: x_val.ncols(),
            });
        }

        // Early stopping is most beneficial for regularized methods
        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // For Ridge regression, compute validation score manually
                let fitted_model = LinearRegression::new()
                    .penalty(self.config.penalty)
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                // Compute validation R² score
                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // For OLS, compute validation score manually
                let fitted_model = LinearRegression::new()
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                // Compute validation R² score
                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_linear_regression_simple() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::new()
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);

        let predictions = model.predict(&array![[5.0]]).unwrap();
        assert_abs_diff_eq!(predictions[0], 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_linear_regression_with_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0]; // y = 2x + 1

        let model = LinearRegression::new()
            .fit_intercept(true)
            .fit(&x, &y)
            .unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(model.intercept().unwrap(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ridge_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::new()
            .fit_intercept(false)
            .regularization(0.1)
            .fit(&x, &y)
            .unwrap();

        // With regularization, coefficient should be slightly less than 2.0
        assert!(model.coef()[0] < 2.0);
        assert!(model.coef()[0] > 1.9);
    }

    #[test]
    fn test_lasso_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        // Test with small alpha
        let model = LinearRegression::lasso(0.01)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        // Should be close to OLS solution (coef = 2.0)
        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 0.1);

        // Test with larger alpha
        let model = LinearRegression::lasso(0.5)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        // Coefficient should be shrunk
        assert!(model.coef()[0] < 2.0);
        assert!(model.coef()[0] > 1.0);
    }

    #[test]
    fn test_elastic_net_regression() {
        let x = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let y = array![3.0, 6.0, 9.0, 12.0]; // y = 2*x1 + 2*x2

        let model = LinearRegression::elastic_net(0.1, 0.5)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        // Both coefficients should be shrunk but non-zero
        println!(
            "ElasticNet coef[0] = {}, coef[1] = {}",
            model.coef()[0],
            model.coef()[1]
        );
        assert!(model.coef()[0] > 0.0);
        assert!(model.coef()[0] < 3.0); // More lenient bound for weak regularization
        assert!(model.coef()[1] > 0.0);
        assert!(model.coef()[1] < 3.0); // More lenient bound for weak regularization
    }

    #[test]
    fn test_lasso_sparsity() {
        // Create data where only first feature is relevant
        let n_samples = 20;
        let mut x = Array2::zeros((n_samples, 5));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            x[[i, 0]] = i as f64;
            x[[i, 1]] = (i as f64) * 0.1; // weak feature
                                          // Add deterministic noise instead of random
            x[[i, 2]] = ((i * 7) % 10) as f64 / 10.0; // pseudo-random noise
            x[[i, 3]] = ((i * 13) % 10) as f64 / 10.0; // pseudo-random noise
            x[[i, 4]] = ((i * 17) % 10) as f64 / 10.0; // pseudo-random noise
            y[i] = 2.0 * x[[i, 0]] + 0.05 * (i % 3) as f64;
        }

        // With strong L1 penalty, should select only the first feature
        let model = LinearRegression::lasso(1.0)
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        let coef = model.coef();

        // First coefficient should be non-zero
        assert!(coef[0] > 0.5);

        // Other coefficients should be zero or very small
        for i in 2..5 {
            assert_abs_diff_eq!(coef[i], 0.0, epsilon = 0.01);
        }
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_lasso() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        // Create larger dataset for meaningful validation split
        let n_samples = 100;
        let n_features = 8;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        // Generate synthetic data with linear relationship
        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i * j + 1) as f64 / 20.0;
            }
            // Only first few features are relevant
            y[i] = 2.0 * x[[i, 0]] + 1.5 * x[[i, 1]] + 0.8 * x[[i, 2]] + 0.1 * (i as f64 % 5.0);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(10),
            validation_split: 0.25,
            shuffle: true,
            random_state: Some(42),
            higher_is_better: true,
            min_iterations: 5,
            restore_best_weights: true,
        };

        let model = LinearRegression::lasso(0.1);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        // Check model properties
        assert_eq!(fitted_model.coef().len(), n_features);
        assert!(fitted_model.intercept().is_some());

        // Check validation info
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
        assert!(validation_info.best_iteration >= 1);

        // Predictions should work
        let predictions = fitted_model.predict(&x);
        assert!(predictions.is_ok());
        assert_eq!(predictions.unwrap().len(), n_samples);
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_elastic_net() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 3.0, 1.0],
            [3.0, 4.0, 1.5],
            [4.0, 5.0, 2.0],
            [5.0, 6.0, 2.5],
            [6.0, 7.0, 3.0],
            [7.0, 8.0, 3.5],
            [8.0, 9.0, 4.0]
        ];
        let y = array![4.5, 7.0, 9.5, 12.0, 14.5, 17.0, 19.5, 22.0]; // y = x1 + x2 + x3 + 1

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TolerancePatience {
                tolerance: 0.005,
                patience: 3,
            },
            validation_split: 0.25,
            shuffle: false,
            random_state: Some(123),
            higher_is_better: true,
            min_iterations: 2,
            restore_best_weights: true,
        };

        let model = LinearRegression::elastic_net(0.1, 0.7);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 3);
        assert!(fitted_model.intercept().is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_with_split() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        // Training data
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let y_train = array![5.0, 8.0, 11.0, 14.0, 17.0]; // y = 2*x1 + x2 + 1

        // Validation data
        let x_val = array![[6.0, 7.0], [7.0, 8.0]];
        let y_val = array![20.0, 23.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TargetScore(0.9),
            validation_split: 0.2, // Ignored since we provide split data
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: false,
        };

        let model = LinearRegression::lasso(0.01);
        let result = model.fit_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            early_stopping_config,
        );

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 2);
        assert!(fitted_model.intercept().is_some());
        assert!(!validation_info.validation_scores.is_empty());

        // Coefficients should be close to true values [2, 1] with small regularization
        let coef = fitted_model.coef();
        assert!((coef[0] - 2.0).abs() < 0.5);
        assert!((coef[1] - 1.0).abs() < 0.5);
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_ols() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0, 13.0]; // y = 2*x + 1

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(5),
            validation_split: 0.33,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        // For OLS (no penalty), early stopping returns dummy validation info
        let model = LinearRegression::new().fit_intercept(true);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 1);
        assert!(fitted_model.intercept().is_some());

        // For OLS, validation info indicates no early stopping occurred
        assert!(!validation_info.stopped_early);
        assert!(validation_info.converged);
        assert_eq!(validation_info.best_iteration, 1);

        // Model should still work correctly
        assert_abs_diff_eq!(fitted_model.coef()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted_model.intercept().unwrap(), 1.0, epsilon = 1e-10);
    }

    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_ridge() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![
            [1.0, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
            [6.0, 3.0]
        ];
        let y = array![2.5, 4.0, 5.5, 7.0, 8.5, 10.0]; // y = 1.5*x1 + 1 (x2 = 0.5*x1, so the features are collinear)

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(3),
            validation_split: 0.33,
            shuffle: true,
            random_state: Some(456),
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: false,
        };

        // For Ridge regression, early stopping currently returns dummy validation info
        let model = LinearRegression::new()
            .regularization(0.1)
            .fit_intercept(true);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.unwrap();

        assert_eq!(fitted_model.coef().len(), 2);
        assert!(fitted_model.intercept().is_some());

        // For Ridge, early stopping is not fully implemented yet, so it should indicate convergence
        assert!(!validation_info.stopped_early);
        assert!(validation_info.converged);
    }
}