aprender/linear_model/
mod.rs

1//! Linear models for regression.
2//!
3//! Includes Ordinary Least Squares (OLS) and regularized regression.
4
5use crate::error::Result;
6use crate::metrics::r_squared;
7use crate::primitives::{Matrix, Vector};
8use crate::traits::Estimator;
9use serde::{Deserialize, Serialize};
10use std::fs;
11use std::path::Path;
12
/// Ordinary Least Squares (OLS) linear regression.
///
/// Fits a linear model by minimizing the residual sum of squares between
/// observed targets and predicted targets. The model equation is:
///
/// ```text
/// y = X β + ε
/// ```
///
/// where `β` is the coefficient vector and `ε` is random error.
///
/// # Solver
///
/// Uses normal equations: `β = (X^T X)^-1 X^T y` via Cholesky decomposition.
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
///
/// // Simple linear regression: y = 2x + 1
/// let x = Matrix::from_vec(4, 1, vec![
///     1.0,
///     2.0,
///     3.0,
///     4.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
///
/// let mut model = LinearRegression::new();
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.99);
/// ```
///
/// # Performance
///
/// - Time complexity: O(n²p + p³) where n = samples, p = features
/// - Space complexity: O(np)
///
/// # Serialization
///
/// Derives `Serialize`/`Deserialize`, which back the `save`/`load`
/// (bincode) and SafeTensors methods on this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression {
    /// Coefficients for features (excluding intercept). `None` until `fit` succeeds.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term. 0.0 when unfitted or when `fit_intercept` is false.
    intercept: f32,
    /// Whether to fit an intercept.
    fit_intercept: bool,
}
63
64impl Default for LinearRegression {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl LinearRegression {
71    /// Creates a new `LinearRegression` with default settings.
72    #[must_use]
73    pub fn new() -> Self {
74        Self {
75            coefficients: None,
76            intercept: 0.0,
77            fit_intercept: true,
78        }
79    }
80
81    /// Sets whether to fit an intercept term.
82    #[must_use]
83    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
84        self.fit_intercept = fit_intercept;
85        self
86    }
87
88    /// Returns the coefficients (excluding intercept).
89    ///
90    /// # Panics
91    ///
92    /// Panics if model is not fitted.
93    #[must_use]
94    pub fn coefficients(&self) -> &Vector<f32> {
95        self.coefficients
96            .as_ref()
97            .expect("Model not fitted. Call fit() first.")
98    }
99
100    /// Returns the intercept term.
101    #[must_use]
102    pub fn intercept(&self) -> f32 {
103        self.intercept
104    }
105
106    /// Returns true if the model has been fitted.
107    #[must_use]
108    pub fn is_fitted(&self) -> bool {
109        self.coefficients.is_some()
110    }
111
112    /// Saves the model to a binary file using bincode.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if serialization or file writing fails.
117    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
118        let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
119        fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
120        Ok(())
121    }
122
123    /// Loads a model from a binary file.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if file reading or deserialization fails.
128    pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
129        let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
130        let model =
131            bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
132        Ok(model)
133    }
134
135    /// Saves the model to SafeTensors format.
136    ///
137    /// SafeTensors format is compatible with:
138    /// - HuggingFace ecosystem
139    /// - Ollama (can convert to GGUF)
140    /// - PyTorch, TensorFlow
141    /// - realizar inference engine
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if:
146    /// - Model is not fitted
147    /// - Serialization fails
148    /// - File writing fails
149    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
150        use crate::serialization::safetensors;
151        use std::collections::BTreeMap;
152
153        // Verify model is fitted
154        let coefficients = self
155            .coefficients
156            .as_ref()
157            .ok_or("Cannot save unfitted model. Call fit() first.")?;
158
159        // Prepare tensors (BTreeMap ensures deterministic ordering)
160        let mut tensors = BTreeMap::new();
161
162        // Coefficients tensor
163        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
164        let coef_shape = vec![coefficients.len()];
165        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
166
167        // Intercept tensor
168        let intercept_data = vec![self.intercept];
169        let intercept_shape = vec![1];
170        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
171
172        // Save to SafeTensors format
173        safetensors::save_safetensors(path, &tensors)?;
174        Ok(())
175    }
176
177    /// Loads a model from SafeTensors format.
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if:
182    /// - File reading fails
183    /// - SafeTensors format is invalid
184    /// - Required tensors are missing
185    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
186        use crate::serialization::safetensors;
187
188        // Load SafeTensors file
189        let (metadata, raw_data) = safetensors::load_safetensors(path)?;
190
191        // Extract coefficients tensor
192        let coef_meta = metadata
193            .get("coefficients")
194            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
195        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
196
197        // Extract intercept tensor
198        let intercept_meta = metadata
199            .get("intercept")
200            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
201        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
202
203        // Validate intercept shape
204        if intercept_data.len() != 1 {
205            return Err(format!(
206                "Invalid intercept tensor: expected 1 value, got {}",
207                intercept_data.len()
208            ));
209        }
210
211        // Construct model
212        Ok(Self {
213            coefficients: Some(Vector::from_vec(coef_data)),
214            intercept: intercept_data[0],
215            fit_intercept: true, // Default to true for loaded models
216        })
217    }
218
219    /// Adds an intercept column of ones to the design matrix.
220    fn add_intercept_column(x: &Matrix<f32>) -> Matrix<f32> {
221        let (n_rows, n_cols) = x.shape();
222        let mut data = Vec::with_capacity(n_rows * (n_cols + 1));
223
224        for i in 0..n_rows {
225            data.push(1.0); // Intercept column
226            for j in 0..n_cols {
227                data.push(x.get(i, j));
228            }
229        }
230
231        Matrix::from_vec(n_rows, n_cols + 1, data)
232            .expect("Internal error: failed to create design matrix")
233    }
234}
235
236impl Estimator for LinearRegression {
237    /// Fits the linear regression model using normal equations.
238    ///
239    /// Solves: β = (X^T X)^-1 X^T y
240    ///
241    /// # Errors
242    ///
243    /// Returns an error if:
244    /// - Input dimensions don't match
245    /// - Not enough samples for the number of features (underdetermined system)
246    /// - Matrix is singular (not positive definite)
247    fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
248        let (n_samples, n_features) = x.shape();
249
250        if n_samples != y.len() {
251            return Err("Number of samples must match target length".into());
252        }
253
254        if n_samples == 0 {
255            return Err("Cannot fit with zero samples".into());
256        }
257
258        // Check for underdetermined system
259        // When fitting intercept, we need n_samples >= n_features + 1
260        // Without intercept, we need n_samples >= n_features
261        let required_samples = if self.fit_intercept {
262            n_features + 1
263        } else {
264            n_features
265        };
266
267        if n_samples < required_samples {
268            return Err(
269                "Insufficient samples: LinearRegression requires at least as many samples as \
270                 features (plus 1 if fitting intercept). Consider using Ridge regression or \
271                 collecting more training data"
272                    .into(),
273            );
274        }
275
276        // Create design matrix (with or without intercept)
277        let x_design = if self.fit_intercept {
278            Self::add_intercept_column(x)
279        } else {
280            x.clone()
281        };
282
283        // Compute X^T X
284        let xt = x_design.transpose();
285        let xtx = xt.matmul(&x_design)?;
286
287        // Compute X^T y
288        let xty = xt.matvec(y)?;
289
290        // Solve normal equations via Cholesky decomposition
291        let beta = xtx.cholesky_solve(&xty)?;
292
293        // Extract intercept and coefficients
294        if self.fit_intercept {
295            self.intercept = beta[0];
296            self.coefficients = Some(beta.slice(1, n_features + 1));
297        } else {
298            self.intercept = 0.0;
299            self.coefficients = Some(beta);
300        }
301
302        Ok(())
303    }
304
305    /// Predicts target values for input data.
306    ///
307    /// # Panics
308    ///
309    /// Panics if model is not fitted.
310    fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
311        let coefficients = self
312            .coefficients
313            .as_ref()
314            .expect("Model not fitted. Call fit() first.");
315
316        let result = x
317            .matvec(coefficients)
318            .expect("Matrix dimensions don't match coefficients");
319
320        result.add_scalar(self.intercept)
321    }
322
323    /// Computes the R² score.
324    fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
325        let y_pred = self.predict(x);
326        r_squared(&y_pred, y)
327    }
328}
329
/// Ridge regression with L2 regularization.
///
/// Fits a linear model with L2 penalty on coefficient magnitudes.
/// The optimization objective is:
///
/// ```text
/// minimize ||y - Xβ||² + α||β||²
/// ```
///
/// where `α` (alpha) controls the regularization strength.
///
/// # Solver
///
/// Uses regularized normal equations: `β = (X^T X + αI)^-1 X^T y`
///
/// # When to use Ridge
///
/// - When you have many correlated features (multicollinearity)
/// - To prevent overfitting with limited samples
/// - When all features are expected to contribute
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
/// use aprender::linear_model::Ridge;
///
/// // Data with some noise
/// let x = Matrix::from_vec(5, 2, vec![
///     1.0, 2.0,
///     2.0, 3.0,
///     3.0, 4.0,
///     4.0, 5.0,
///     5.0, 6.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[5.0, 8.0, 11.0, 14.0, 17.0]);
///
/// let mut model = Ridge::new(1.0);  // alpha = 1.0
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.9);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ridge {
    /// Regularization strength (lambda/alpha).
    alpha: f32,
    /// Coefficients for features (excluding intercept). `None` until `fit` succeeds.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term. 0.0 when unfitted or when `fit_intercept` is false.
    intercept: f32,
    /// Whether to fit an intercept. The intercept itself is not regularized.
    fit_intercept: bool,
}
385
impl Ridge {
    /// Creates a new `Ridge` regression with the given regularization strength.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Regularization strength. Larger values = more regularization.
    ///   Must be non-negative. Use 0.0 for no regularization (equivalent to OLS).
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        Self {
            alpha,
            coefficients: None,
            intercept: 0.0,
            fit_intercept: true,
        }
    }

    /// Sets whether to fit an intercept term.
    #[must_use]
    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Returns the regularization strength (alpha).
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Returns the coefficients (excluding intercept).
    ///
    /// # Panics
    ///
    /// Panics if model is not fitted.
    #[must_use]
    pub fn coefficients(&self) -> &Vector<f32> {
        self.coefficients
            .as_ref()
            .expect("Model not fitted. Call fit() first.")
    }

    /// Returns the intercept term.
    #[must_use]
    pub fn intercept(&self) -> f32 {
        self.intercept
    }

    /// Returns true if the model has been fitted.
    #[must_use]
    pub fn is_fitted(&self) -> bool {
        self.coefficients.is_some()
    }

    /// Saves the model to a binary file using bincode.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or file writing fails.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
        fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
        Ok(())
    }

    /// Loads a model from a binary file.
    ///
    /// # Errors
    ///
    /// Returns an error if file reading or deserialization fails.
    pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
        let model =
            bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
        Ok(model)
    }

    /// Saves the model to SafeTensors format.
    ///
    /// SafeTensors format is compatible with:
    /// - HuggingFace ecosystem
    /// - Ollama (can convert to GGUF)
    /// - PyTorch, TensorFlow
    /// - realizar inference engine
    ///
    /// NOTE(review): `fit_intercept` is not written to the file, so it does not
    /// round-trip through SafeTensors (load assumes `true`) — confirm whether
    /// the format should carry it.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Model is not fitted
    /// - Serialization fails
    /// - File writing fails
    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        use crate::serialization::safetensors;
        use std::collections::BTreeMap;

        // Verify model is fitted
        let coefficients = self
            .coefficients
            .as_ref()
            .ok_or("Cannot save unfitted model. Call fit() first.")?;

        // Prepare tensors (BTreeMap ensures deterministic ordering)
        let mut tensors = BTreeMap::new();

        // Coefficients tensor
        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
        let coef_shape = vec![coefficients.len()];
        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));

        // Intercept tensor
        let intercept_data = vec![self.intercept];
        let intercept_shape = vec![1];
        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));

        // Alpha (regularization strength) as tensor
        let alpha_data = vec![self.alpha];
        let alpha_shape = vec![1];
        tensors.insert("alpha".to_string(), (alpha_data, alpha_shape));

        // Save to SafeTensors format
        safetensors::save_safetensors(path, &tensors)?;
        Ok(())
    }

    /// Loads a model from SafeTensors format.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - File reading fails
    /// - SafeTensors format is invalid
    /// - Required tensors are missing
    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        use crate::serialization::safetensors;

        // Load SafeTensors file
        let (metadata, raw_data) = safetensors::load_safetensors(path)?;

        // Extract coefficients tensor
        let coef_meta = metadata
            .get("coefficients")
            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;

        // Extract intercept tensor
        let intercept_meta = metadata
            .get("intercept")
            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;

        // Extract alpha tensor
        let alpha_meta = metadata
            .get("alpha")
            .ok_or("Missing 'alpha' tensor in SafeTensors file")?;
        let alpha_data = safetensors::extract_tensor(&raw_data, alpha_meta)?;

        // Validate tensor sizes (both scalars must be stored as 1-element tensors)
        if intercept_data.len() != 1 {
            return Err(format!(
                "Expected intercept tensor to have 1 element, got {}",
                intercept_data.len()
            ));
        }

        if alpha_data.len() != 1 {
            return Err(format!(
                "Expected alpha tensor to have 1 element, got {}",
                alpha_data.len()
            ));
        }

        // Reconstruct model
        Ok(Self {
            alpha: alpha_data[0],
            coefficients: Some(Vector::from_vec(coef_data)),
            intercept: intercept_data[0],
            // NOTE(review): flag not stored in the file; see save_safetensors.
            fit_intercept: true, // Default to true for loaded models
        })
    }
}
566
567impl Estimator for Ridge {
568    /// Fits the Ridge regression model using regularized normal equations.
569    ///
570    /// Solves: β = (X^T X + αI)^-1 X^T y
571    ///
572    /// # Errors
573    ///
574    /// Returns an error if input dimensions don't match or matrix is singular.
575    fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
576        let (n_samples, n_features) = x.shape();
577
578        if n_samples != y.len() {
579            return Err("Number of samples must match target length".into());
580        }
581
582        if n_samples == 0 {
583            return Err("Cannot fit with zero samples".into());
584        }
585
586        // Create design matrix (with or without intercept)
587        let x_design = if self.fit_intercept {
588            LinearRegression::add_intercept_column(x)
589        } else {
590            x.clone()
591        };
592
593        let n_params = if self.fit_intercept {
594            n_features + 1
595        } else {
596            n_features
597        };
598
599        // Compute X^T X
600        let xt = x_design.transpose();
601        let mut xtx = xt.matmul(&x_design)?;
602
603        // Add regularization: X^T X + αI
604        // Note: We don't regularize the intercept term
605        for i in 0..n_params {
606            // Skip intercept if fitting intercept (first column)
607            if self.fit_intercept && i == 0 {
608                continue;
609            }
610            let current = xtx.get(i, i);
611            xtx.set(i, i, current + self.alpha);
612        }
613
614        // Compute X^T y
615        let xty = xt.matvec(y)?;
616
617        // Solve regularized normal equations via Cholesky decomposition
618        let beta = xtx.cholesky_solve(&xty)?;
619
620        // Extract intercept and coefficients
621        if self.fit_intercept {
622            self.intercept = beta[0];
623            self.coefficients = Some(beta.slice(1, n_features + 1));
624        } else {
625            self.intercept = 0.0;
626            self.coefficients = Some(beta);
627        }
628
629        Ok(())
630    }
631
632    /// Predicts target values for input data.
633    ///
634    /// # Panics
635    ///
636    /// Panics if model is not fitted.
637    fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
638        let coefficients = self
639            .coefficients
640            .as_ref()
641            .expect("Model not fitted. Call fit() first.");
642
643        let result = x
644            .matvec(coefficients)
645            .expect("Matrix dimensions don't match coefficients");
646
647        result.add_scalar(self.intercept)
648    }
649
650    /// Computes the R² score.
651    fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
652        let y_pred = self.predict(x);
653        r_squared(&y_pred, y)
654    }
655}
656
/// Lasso regression with L1 regularization.
///
/// Fits a linear model with L1 penalty on coefficient magnitudes.
/// The optimization objective is:
///
/// ```text
/// minimize ||y - Xβ||² + α||β||₁
/// ```
///
/// where `α` (alpha) controls the regularization strength.
///
/// # Solver
///
/// Uses coordinate descent with soft-thresholding.
///
/// # When to use Lasso
///
/// - For automatic feature selection (produces sparse models)
/// - When you expect only a few features to be relevant
/// - When interpretability through sparsity is desired
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
/// use aprender::linear_model::Lasso;
///
/// // Data with some features
/// let x = Matrix::from_vec(5, 2, vec![
///     1.0, 2.0,
///     2.0, 3.0,
///     3.0, 4.0,
///     4.0, 5.0,
///     5.0, 6.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[5.0, 8.0, 11.0, 14.0, 17.0]);
///
/// let mut model = Lasso::new(0.1);  // alpha = 0.1
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.9);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Lasso {
    /// Regularization strength (lambda/alpha).
    alpha: f32,
    /// Coefficients for features (excluding intercept). `None` until `fit` succeeds.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term.
    intercept: f32,
    /// Whether to fit an intercept.
    fit_intercept: bool,
    /// Maximum number of iterations for coordinate descent (default 1000, see `new`).
    max_iter: usize,
    /// Tolerance for convergence (default 1e-4, see `new`).
    tol: f32,
}
716
impl Lasso {
    /// Creates a new `Lasso` regression with the given regularization strength.
    ///
    /// Defaults: intercept fitting enabled, `max_iter = 1000`, `tol = 1e-4`.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Regularization strength. Larger values = more sparsity.
    ///   Must be non-negative.
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        Self {
            alpha,
            coefficients: None,
            intercept: 0.0,
            fit_intercept: true,
            max_iter: 1000,
            tol: 1e-4,
        }
    }

    /// Sets whether to fit an intercept term.
    #[must_use]
    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Sets the maximum number of iterations.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: f32) -> Self {
        self.tol = tol;
        self
    }

    /// Returns the regularization strength (alpha).
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Returns the coefficients (excluding intercept).
    ///
    /// # Panics
    ///
    /// Panics if model is not fitted.
    #[must_use]
    pub fn coefficients(&self) -> &Vector<f32> {
        self.coefficients
            .as_ref()
            .expect("Model not fitted. Call fit() first.")
    }

    /// Returns the intercept term.
    #[must_use]
    pub fn intercept(&self) -> f32 {
        self.intercept
    }

    /// Returns true if the model has been fitted.
    #[must_use]
    pub fn is_fitted(&self) -> bool {
        self.coefficients.is_some()
    }

    /// Soft-thresholding operator for L1 regularization.
    ///
    /// Shrinks `x` toward zero by `lambda`, clamping to exactly 0.0 when
    /// `|x| <= lambda` — this is what produces sparse coefficients.
    fn soft_threshold(x: f32, lambda: f32) -> f32 {
        if x > lambda {
            x - lambda
        } else if x < -lambda {
            x + lambda
        } else {
            0.0
        }
    }

    /// Saves the model to a binary file using bincode.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or file writing fails.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
        fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
        Ok(())
    }

    /// Loads a model from a binary file.
    ///
    /// # Errors
    ///
    /// Returns an error if file reading or deserialization fails.
    pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
        let model =
            bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
        Ok(model)
    }

    /// Saves the model to SafeTensors format.
    ///
    /// SafeTensors format is compatible with:
    /// - HuggingFace ecosystem
    /// - Ollama (can convert to GGUF)
    /// - PyTorch, TensorFlow
    /// - realizar inference engine
    ///
    /// NOTE(review): `fit_intercept` is not written to the file, so it does not
    /// round-trip through SafeTensors (load assumes `true`) — confirm whether
    /// the format should carry it.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Model is not fitted
    /// - Serialization fails
    /// - File writing fails
    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        use crate::serialization::safetensors;
        use std::collections::BTreeMap;

        // Verify model is fitted
        let coefficients = self
            .coefficients
            .as_ref()
            .ok_or("Cannot save unfitted model. Call fit() first.")?;

        // Prepare tensors (BTreeMap ensures deterministic ordering)
        let mut tensors = BTreeMap::new();

        // Coefficients tensor
        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
        let coef_shape = vec![coefficients.len()];
        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));

        // Intercept tensor
        let intercept_data = vec![self.intercept];
        let intercept_shape = vec![1];
        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));

        // Alpha (regularization strength) as tensor
        let alpha_data = vec![self.alpha];
        let alpha_shape = vec![1];
        tensors.insert("alpha".to_string(), (alpha_data, alpha_shape));

        // Max iterations as tensor (stored as f32 for consistency)
        // NOTE(review): f32 represents integers exactly only up to 2^24, so
        // max_iter values above ~16.7M lose precision on round-trip.
        let max_iter_data = vec![self.max_iter as f32];
        let max_iter_shape = vec![1];
        tensors.insert("max_iter".to_string(), (max_iter_data, max_iter_shape));

        // Tolerance as tensor
        let tol_data = vec![self.tol];
        let tol_shape = vec![1];
        tensors.insert("tol".to_string(), (tol_data, tol_shape));

        // Save to SafeTensors format
        safetensors::save_safetensors(path, &tensors)?;
        Ok(())
    }

    /// Loads a model from SafeTensors format.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - File reading fails
    /// - SafeTensors format is invalid
    /// - Required tensors are missing
    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        use crate::serialization::safetensors;

        // Load SafeTensors file
        let (metadata, raw_data) = safetensors::load_safetensors(path)?;

        // Extract coefficients tensor
        let coef_meta = metadata
            .get("coefficients")
            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;

        // Extract intercept tensor
        let intercept_meta = metadata
            .get("intercept")
            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;

        // Extract alpha tensor
        let alpha_meta = metadata
            .get("alpha")
            .ok_or("Missing 'alpha' tensor in SafeTensors file")?;
        let alpha_data = safetensors::extract_tensor(&raw_data, alpha_meta)?;

        // Extract max_iter tensor
        let max_iter_meta = metadata
            .get("max_iter")
            .ok_or("Missing 'max_iter' tensor in SafeTensors file")?;
        let max_iter_data = safetensors::extract_tensor(&raw_data, max_iter_meta)?;

        // Extract tol tensor
        let tol_meta = metadata
            .get("tol")
            .ok_or("Missing 'tol' tensor in SafeTensors file")?;
        let tol_data = safetensors::extract_tensor(&raw_data, tol_meta)?;

        // Validate tensor sizes (all scalars must be 1-element tensors)
        if intercept_data.len() != 1 {
            return Err(format!(
                "Expected intercept tensor to have 1 element, got {}",
                intercept_data.len()
            ));
        }

        if alpha_data.len() != 1 {
            return Err(format!(
                "Expected alpha tensor to have 1 element, got {}",
                alpha_data.len()
            ));
        }

        if max_iter_data.len() != 1 {
            return Err(format!(
                "Expected max_iter tensor to have 1 element, got {}",
                max_iter_data.len()
            ));
        }

        if tol_data.len() != 1 {
            return Err(format!(
                "Expected tol tensor to have 1 element, got {}",
                tol_data.len()
            ));
        }

        // Reconstruct model
        Ok(Self {
            alpha: alpha_data[0],
            coefficients: Some(Vector::from_vec(coef_data)),
            intercept: intercept_data[0],
            // NOTE(review): flag not stored in the file; see save_safetensors.
            fit_intercept: true, // Default to true for loaded models
            // NOTE(review): saturating float-to-int cast; negative/NaN values
            // in a hand-crafted file would become 0 — confirm acceptable.
            max_iter: max_iter_data[0] as usize,
            tol: tol_data[0],
        })
    }
}
962
963impl Estimator for Lasso {
964    /// Fits the Lasso regression model using coordinate descent.
965    ///
966    /// # Errors
967    ///
968    /// Returns an error if input dimensions don't match.
969    fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
970        let (n_samples, n_features) = x.shape();
971
972        if n_samples != y.len() {
973            return Err("Number of samples must match target length".into());
974        }
975
976        if n_samples == 0 {
977            return Err("Cannot fit with zero samples".into());
978        }
979
980        // Center data if fitting intercept
981        let (x_centered, y_centered, y_mean) = if self.fit_intercept {
982            // Compute means
983            let mut x_mean = vec![0.0; n_features];
984            let mut y_sum = 0.0;
985
986            for i in 0..n_samples {
987                for (j, mean_j) in x_mean.iter_mut().enumerate() {
988                    *mean_j += x.get(i, j);
989                }
990                y_sum += y[i];
991            }
992
993            for mean in &mut x_mean {
994                *mean /= n_samples as f32;
995            }
996            let y_mean = y_sum / n_samples as f32;
997
998            // Center data
999            let mut x_data = vec![0.0; n_samples * n_features];
1000            let mut y_data = vec![0.0; n_samples];
1001
1002            for i in 0..n_samples {
1003                for j in 0..n_features {
1004                    x_data[i * n_features + j] = x.get(i, j) - x_mean[j];
1005                }
1006                y_data[i] = y[i] - y_mean;
1007            }
1008
1009            (
1010                Matrix::from_vec(n_samples, n_features, x_data)
1011                    .expect("Valid matrix dimensions for property test"),
1012                Vector::from_vec(y_data),
1013                y_mean,
1014            )
1015        } else {
1016            (x.clone(), y.clone(), 0.0)
1017        };
1018
1019        // Initialize coefficients to zero
1020        let mut beta = vec![0.0; n_features];
1021
1022        // Precompute X^T X diagonal (column norms squared)
1023        let mut col_norms_sq = vec![0.0; n_features];
1024        for (j, norm_sq) in col_norms_sq.iter_mut().enumerate() {
1025            for i in 0..n_samples {
1026                let val = x_centered.get(i, j);
1027                *norm_sq += val * val;
1028            }
1029        }
1030
1031        // Coordinate descent
1032        for _ in 0..self.max_iter {
1033            let mut max_change = 0.0f32;
1034
1035            for j in 0..n_features {
1036                if col_norms_sq[j] < 1e-10 {
1037                    continue; // Skip zero-variance features
1038                }
1039
1040                // Compute residual without current feature
1041                let mut rho = 0.0;
1042                for i in 0..n_samples {
1043                    let mut pred = 0.0;
1044                    for (k, &beta_k) in beta.iter().enumerate() {
1045                        if k != j {
1046                            pred += x_centered.get(i, k) * beta_k;
1047                        }
1048                    }
1049                    let residual = y_centered[i] - pred;
1050                    rho += x_centered.get(i, j) * residual;
1051                }
1052
1053                // Update coefficient with soft-thresholding
1054                let old_beta = beta[j];
1055                beta[j] = Self::soft_threshold(rho, self.alpha) / col_norms_sq[j];
1056
1057                let change = (beta[j] - old_beta).abs();
1058                if change > max_change {
1059                    max_change = change;
1060                }
1061            }
1062
1063            // Check convergence
1064            if max_change < self.tol {
1065                break;
1066            }
1067        }
1068
1069        // Set intercept
1070        if self.fit_intercept {
1071            let mut intercept = y_mean;
1072            let mut x_mean = vec![0.0; n_features];
1073            for j in 0..n_features {
1074                for i in 0..n_samples {
1075                    x_mean[j] += x.get(i, j);
1076                }
1077                x_mean[j] /= n_samples as f32;
1078                intercept -= beta[j] * x_mean[j];
1079            }
1080            self.intercept = intercept;
1081        } else {
1082            self.intercept = 0.0;
1083        }
1084
1085        self.coefficients = Some(Vector::from_vec(beta));
1086        Ok(())
1087    }
1088
1089    /// Predicts target values for input data.
1090    ///
1091    /// # Panics
1092    ///
1093    /// Panics if model is not fitted.
1094    fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
1095        let coefficients = self
1096            .coefficients
1097            .as_ref()
1098            .expect("Model not fitted. Call fit() first.");
1099
1100        let result = x
1101            .matvec(coefficients)
1102            .expect("Matrix dimensions don't match coefficients");
1103
1104        result.add_scalar(self.intercept)
1105    }
1106
1107    /// Computes the R² score.
1108    fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
1109        let y_pred = self.predict(x);
1110        r_squared(&y_pred, y)
1111    }
1112}
1113
1114/// Elastic Net regression with combined L1 and L2 regularization.
1115///
1116/// Fits a linear model with both L1 and L2 penalties:
1117///
1118/// ```text
1119/// minimize ||y - Xβ||² + α * l1_ratio * ||β||₁ + α * (1 - l1_ratio) * ||β||²
1120/// ```
1121///
1122/// # Parameters
1123///
1124/// - `alpha` - Overall regularization strength
1125/// - `l1_ratio` - Mix between L1 and L2 (0.0 = Ridge, 1.0 = Lasso)
1126///
1127/// # Solver
1128///
1129/// Uses coordinate descent with combined soft-thresholding and shrinkage.
1130///
1131/// # When to use Elastic Net
1132///
1133/// - When you want both sparsity (L1) and grouping effect (L2)
1134/// - With correlated features where Lasso may be unstable
1135/// - When you don't know which regularization type to use
1136///
1137/// # Examples
1138///
1139/// ```
1140/// use aprender::prelude::*;
1141/// use aprender::linear_model::ElasticNet;
1142///
1143/// let x = Matrix::from_vec(5, 2, vec![
1144///     1.0, 2.0,
1145///     2.0, 3.0,
1146///     3.0, 4.0,
1147///     4.0, 5.0,
1148///     5.0, 6.0,
1149/// ]).expect("Valid matrix dimensions");
1150/// let y = Vector::from_slice(&[5.0, 8.0, 11.0, 14.0, 17.0]);
1151///
1152/// // 50% L1, 50% L2
1153/// let mut model = ElasticNet::new(0.1, 0.5);
1154/// model.fit(&x, &y).expect("Fit should succeed with valid data");
1155///
1156/// let r2 = model.score(&x, &y);
1157/// assert!(r2 > 0.9);
1158/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElasticNet {
    // NOTE(review): `save`/`load` serialize this struct with bincode, which
    // appears to encode fields in declaration order — confirm before
    // reordering fields, as that would break previously saved model files.
    /// Overall regularization strength.
    alpha: f32,
    /// Mix between L1 and L2 (0.0 = pure L2, 1.0 = pure L1).
    /// `new()` clamps this into [0.0, 1.0].
    l1_ratio: f32,
    /// Coefficients for features (excluding intercept). `None` until fitted.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term; 0.0 until fitted or when `fit_intercept` is false.
    intercept: f32,
    /// Whether to fit an intercept.
    fit_intercept: bool,
    /// Maximum number of coordinate-descent iterations.
    max_iter: usize,
    /// Tolerance for convergence (on the largest per-sweep coefficient change).
    tol: f32,
}
1176
1177impl ElasticNet {
1178    /// Creates a new `ElasticNet` with the given parameters.
1179    ///
1180    /// # Arguments
1181    ///
1182    /// * `alpha` - Overall regularization strength
1183    /// * `l1_ratio` - Mix between L1 and L2 (0.0 = Ridge, 1.0 = Lasso)
1184    #[must_use]
1185    pub fn new(alpha: f32, l1_ratio: f32) -> Self {
1186        Self {
1187            alpha,
1188            l1_ratio: l1_ratio.clamp(0.0, 1.0),
1189            coefficients: None,
1190            intercept: 0.0,
1191            fit_intercept: true,
1192            max_iter: 1000,
1193            tol: 1e-4,
1194        }
1195    }
1196
1197    /// Sets whether to fit an intercept term.
1198    #[must_use]
1199    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
1200        self.fit_intercept = fit_intercept;
1201        self
1202    }
1203
1204    /// Sets the maximum number of iterations.
1205    #[must_use]
1206    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1207        self.max_iter = max_iter;
1208        self
1209    }
1210
1211    /// Sets the convergence tolerance.
1212    #[must_use]
1213    pub fn with_tol(mut self, tol: f32) -> Self {
1214        self.tol = tol;
1215        self
1216    }
1217
1218    /// Returns the regularization strength (alpha).
1219    #[must_use]
1220    pub fn alpha(&self) -> f32 {
1221        self.alpha
1222    }
1223
1224    /// Returns the L1/L2 ratio.
1225    #[must_use]
1226    pub fn l1_ratio(&self) -> f32 {
1227        self.l1_ratio
1228    }
1229
1230    /// Returns the coefficients (excluding intercept).
1231    ///
1232    /// # Panics
1233    ///
1234    /// Panics if model is not fitted.
1235    #[must_use]
1236    pub fn coefficients(&self) -> &Vector<f32> {
1237        self.coefficients
1238            .as_ref()
1239            .expect("Model not fitted. Call fit() first.")
1240    }
1241
1242    /// Returns the intercept term.
1243    #[must_use]
1244    pub fn intercept(&self) -> f32 {
1245        self.intercept
1246    }
1247
1248    /// Returns true if the model has been fitted.
1249    #[must_use]
1250    pub fn is_fitted(&self) -> bool {
1251        self.coefficients.is_some()
1252    }
1253
1254    /// Saves the model to a binary file.
1255    ///
1256    /// # Errors
1257    ///
1258    /// Returns an error if serialization or file writing fails.
1259    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
1260        let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
1261        fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
1262        Ok(())
1263    }
1264
1265    /// Loads a model from a binary file.
1266    ///
1267    /// # Errors
1268    ///
1269    /// Returns an error if file reading or deserialization fails.
1270    pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
1271        let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
1272        let model =
1273            bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
1274        Ok(model)
1275    }
1276
1277    /// Saves the model to SafeTensors format.
1278    ///
1279    /// SafeTensors format is compatible with:
1280    /// - HuggingFace ecosystem
1281    /// - Ollama (can convert to GGUF)
1282    /// - PyTorch, TensorFlow
1283    /// - realizar inference engine
1284    ///
1285    /// # Errors
1286    ///
1287    /// Returns an error if:
1288    /// - Model is not fitted
1289    /// - Serialization fails
1290    /// - File writing fails
1291    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
1292        use crate::serialization::safetensors;
1293        use std::collections::BTreeMap;
1294
1295        // Verify model is fitted
1296        let coefficients = self
1297            .coefficients
1298            .as_ref()
1299            .ok_or("Cannot save unfitted model. Call fit() first.")?;
1300
1301        // Prepare tensors (BTreeMap ensures deterministic ordering)
1302        let mut tensors = BTreeMap::new();
1303
1304        // Coefficients tensor
1305        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
1306        let coef_shape = vec![coefficients.len()];
1307        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
1308
1309        // Intercept tensor
1310        let intercept_data = vec![self.intercept];
1311        let intercept_shape = vec![1];
1312        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
1313
1314        // Alpha (regularization strength) as tensor
1315        let alpha_data = vec![self.alpha];
1316        let alpha_shape = vec![1];
1317        tensors.insert("alpha".to_string(), (alpha_data, alpha_shape));
1318
1319        // L1 ratio as tensor
1320        let l1_ratio_data = vec![self.l1_ratio];
1321        let l1_ratio_shape = vec![1];
1322        tensors.insert("l1_ratio".to_string(), (l1_ratio_data, l1_ratio_shape));
1323
1324        // Max iterations as tensor (stored as f32 for consistency)
1325        let max_iter_data = vec![self.max_iter as f32];
1326        let max_iter_shape = vec![1];
1327        tensors.insert("max_iter".to_string(), (max_iter_data, max_iter_shape));
1328
1329        // Tolerance as tensor
1330        let tol_data = vec![self.tol];
1331        let tol_shape = vec![1];
1332        tensors.insert("tol".to_string(), (tol_data, tol_shape));
1333
1334        // Save to SafeTensors format
1335        safetensors::save_safetensors(path, &tensors)?;
1336        Ok(())
1337    }
1338
1339    /// Loads a model from SafeTensors format.
1340    ///
1341    /// # Errors
1342    ///
1343    /// Returns an error if:
1344    /// - File reading fails
1345    /// - SafeTensors format is invalid
1346    /// - Required tensors are missing
1347    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
1348        use crate::serialization::safetensors;
1349
1350        // Load SafeTensors file
1351        let (metadata, raw_data) = safetensors::load_safetensors(path)?;
1352
1353        // Extract coefficients tensor
1354        let coef_meta = metadata
1355            .get("coefficients")
1356            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
1357        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
1358
1359        // Extract intercept tensor
1360        let intercept_meta = metadata
1361            .get("intercept")
1362            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
1363        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
1364
1365        // Extract alpha tensor
1366        let alpha_meta = metadata
1367            .get("alpha")
1368            .ok_or("Missing 'alpha' tensor in SafeTensors file")?;
1369        let alpha_data = safetensors::extract_tensor(&raw_data, alpha_meta)?;
1370
1371        // Extract l1_ratio tensor
1372        let l1_ratio_meta = metadata
1373            .get("l1_ratio")
1374            .ok_or("Missing 'l1_ratio' tensor in SafeTensors file")?;
1375        let l1_ratio_data = safetensors::extract_tensor(&raw_data, l1_ratio_meta)?;
1376
1377        // Extract max_iter tensor
1378        let max_iter_meta = metadata
1379            .get("max_iter")
1380            .ok_or("Missing 'max_iter' tensor in SafeTensors file")?;
1381        let max_iter_data = safetensors::extract_tensor(&raw_data, max_iter_meta)?;
1382
1383        // Extract tol tensor
1384        let tol_meta = metadata
1385            .get("tol")
1386            .ok_or("Missing 'tol' tensor in SafeTensors file")?;
1387        let tol_data = safetensors::extract_tensor(&raw_data, tol_meta)?;
1388
1389        // Validate tensor sizes
1390        if intercept_data.len() != 1 {
1391            return Err(format!(
1392                "Expected intercept tensor to have 1 element, got {}",
1393                intercept_data.len()
1394            ));
1395        }
1396
1397        if alpha_data.len() != 1 {
1398            return Err(format!(
1399                "Expected alpha tensor to have 1 element, got {}",
1400                alpha_data.len()
1401            ));
1402        }
1403
1404        if l1_ratio_data.len() != 1 {
1405            return Err(format!(
1406                "Expected l1_ratio tensor to have 1 element, got {}",
1407                l1_ratio_data.len()
1408            ));
1409        }
1410
1411        if max_iter_data.len() != 1 {
1412            return Err(format!(
1413                "Expected max_iter tensor to have 1 element, got {}",
1414                max_iter_data.len()
1415            ));
1416        }
1417
1418        if tol_data.len() != 1 {
1419            return Err(format!(
1420                "Expected tol tensor to have 1 element, got {}",
1421                tol_data.len()
1422            ));
1423        }
1424
1425        // Reconstruct model
1426        Ok(Self {
1427            alpha: alpha_data[0],
1428            l1_ratio: l1_ratio_data[0],
1429            coefficients: Some(Vector::from_vec(coef_data)),
1430            intercept: intercept_data[0],
1431            fit_intercept: true, // Default to true for loaded models
1432            max_iter: max_iter_data[0] as usize,
1433            tol: tol_data[0],
1434        })
1435    }
1436}
1437
1438impl Estimator for ElasticNet {
1439    /// Fits the Elastic Net model using coordinate descent.
1440    ///
1441    /// # Errors
1442    ///
1443    /// Returns an error if input dimensions don't match.
1444    fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
1445        let (n_samples, n_features) = x.shape();
1446
1447        if n_samples != y.len() {
1448            return Err("Number of samples must match target length".into());
1449        }
1450
1451        if n_samples == 0 {
1452            return Err("Cannot fit with zero samples".into());
1453        }
1454
1455        // Center data if fitting intercept
1456        let (x_centered, y_centered, y_mean) = if self.fit_intercept {
1457            let mut x_mean = vec![0.0; n_features];
1458            let mut y_sum = 0.0;
1459
1460            for i in 0..n_samples {
1461                for (j, mean_j) in x_mean.iter_mut().enumerate() {
1462                    *mean_j += x.get(i, j);
1463                }
1464                y_sum += y[i];
1465            }
1466
1467            for mean in &mut x_mean {
1468                *mean /= n_samples as f32;
1469            }
1470            let y_mean = y_sum / n_samples as f32;
1471
1472            let mut x_data = vec![0.0; n_samples * n_features];
1473            let mut y_data = vec![0.0; n_samples];
1474
1475            for i in 0..n_samples {
1476                for j in 0..n_features {
1477                    x_data[i * n_features + j] = x.get(i, j) - x_mean[j];
1478                }
1479                y_data[i] = y[i] - y_mean;
1480            }
1481
1482            (
1483                Matrix::from_vec(n_samples, n_features, x_data)
1484                    .expect("Valid matrix dimensions for property test"),
1485                Vector::from_vec(y_data),
1486                y_mean,
1487            )
1488        } else {
1489            (x.clone(), y.clone(), 0.0)
1490        };
1491
1492        // Initialize coefficients
1493        let mut beta = vec![0.0; n_features];
1494
1495        // Precompute column norms squared
1496        let mut col_norms_sq = vec![0.0; n_features];
1497        for (j, norm_sq) in col_norms_sq.iter_mut().enumerate() {
1498            for i in 0..n_samples {
1499                let val = x_centered.get(i, j);
1500                *norm_sq += val * val;
1501            }
1502        }
1503
1504        // L1 and L2 penalties
1505        let l1_penalty = self.alpha * self.l1_ratio;
1506        let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
1507
1508        // Coordinate descent
1509        for _ in 0..self.max_iter {
1510            let mut max_change = 0.0f32;
1511
1512            for j in 0..n_features {
1513                if col_norms_sq[j] < 1e-10 {
1514                    continue;
1515                }
1516
1517                // Compute residual without current feature
1518                let mut rho = 0.0;
1519                for i in 0..n_samples {
1520                    let mut pred = 0.0;
1521                    for (k, &beta_k) in beta.iter().enumerate() {
1522                        if k != j {
1523                            pred += x_centered.get(i, k) * beta_k;
1524                        }
1525                    }
1526                    let residual = y_centered[i] - pred;
1527                    rho += x_centered.get(i, j) * residual;
1528                }
1529
1530                // Update with soft-thresholding (L1) and shrinkage (L2)
1531                let old_beta = beta[j];
1532                let denom = col_norms_sq[j] + l2_penalty;
1533                beta[j] = Lasso::soft_threshold(rho, l1_penalty) / denom;
1534
1535                let change = (beta[j] - old_beta).abs();
1536                if change > max_change {
1537                    max_change = change;
1538                }
1539            }
1540
1541            if max_change < self.tol {
1542                break;
1543            }
1544        }
1545
1546        // Set intercept
1547        if self.fit_intercept {
1548            let mut intercept = y_mean;
1549            let mut x_mean = vec![0.0; n_features];
1550            for j in 0..n_features {
1551                for i in 0..n_samples {
1552                    x_mean[j] += x.get(i, j);
1553                }
1554                x_mean[j] /= n_samples as f32;
1555                intercept -= beta[j] * x_mean[j];
1556            }
1557            self.intercept = intercept;
1558        } else {
1559            self.intercept = 0.0;
1560        }
1561
1562        self.coefficients = Some(Vector::from_vec(beta));
1563        Ok(())
1564    }
1565
1566    /// Predicts target values for input data.
1567    ///
1568    /// # Panics
1569    ///
1570    /// Panics if model is not fitted.
1571    fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
1572        let coefficients = self
1573            .coefficients
1574            .as_ref()
1575            .expect("Model not fitted. Call fit() first.");
1576
1577        let result = x
1578            .matvec(coefficients)
1579            .expect("Matrix dimensions don't match coefficients");
1580
1581        result.add_scalar(self.intercept)
1582    }
1583
1584    /// Computes the R² score.
1585    fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
1586        let y_pred = self.predict(x);
1587        r_squared(&y_pred, y)
1588    }
1589}
1590
1591#[cfg(test)]
1592mod tests {
1593    use super::*;
1594
1595    #[test]
1596    fn test_new() {
1597        let model = LinearRegression::new();
1598        assert!(!model.is_fitted());
1599        assert!(model.fit_intercept);
1600    }
1601
1602    #[test]
1603    fn test_simple_regression() {
1604        // y = 2x + 1
1605        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
1606            .expect("Valid matrix dimensions for test");
1607        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
1608
1609        let mut model = LinearRegression::new();
1610        model
1611            .fit(&x, &y)
1612            .expect("Fit should succeed with valid test data");
1613
1614        assert!(model.is_fitted());
1615
1616        // Check coefficients
1617        let coef = model.coefficients();
1618        assert!((coef[0] - 2.0).abs() < 1e-4);
1619        assert!((model.intercept() - 1.0).abs() < 1e-4);
1620
1621        // Check predictions
1622        let predictions = model.predict(&x);
1623        for i in 0..4 {
1624            assert!((predictions[i] - y[i]).abs() < 1e-4);
1625        }
1626
1627        // Check R²
1628        let r2 = model.score(&x, &y);
1629        assert!((r2 - 1.0).abs() < 1e-4);
1630    }
1631
1632    #[test]
1633    fn test_multivariate_regression() {
1634        // y = 1 + 2*x1 + 3*x2
1635        let x = Matrix::from_vec(4, 2, vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0])
1636            .expect("Valid matrix dimensions for test");
1637        let y = Vector::from_slice(&[6.0, 8.0, 9.0, 11.0]);
1638
1639        let mut model = LinearRegression::new();
1640        model
1641            .fit(&x, &y)
1642            .expect("Fit should succeed with valid test data");
1643
1644        let coef = model.coefficients();
1645        assert!((coef[0] - 2.0).abs() < 1e-4);
1646        assert!((coef[1] - 3.0).abs() < 1e-4);
1647        assert!((model.intercept() - 1.0).abs() < 1e-4);
1648
1649        let r2 = model.score(&x, &y);
1650        assert!((r2 - 1.0).abs() < 1e-4);
1651    }
1652
1653    #[test]
1654    fn test_no_intercept() {
1655        // y = 2x (no intercept)
1656        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
1657            .expect("Valid matrix dimensions for test");
1658        let y = Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]);
1659
1660        let mut model = LinearRegression::new().with_intercept(false);
1661        model
1662            .fit(&x, &y)
1663            .expect("Fit should succeed with valid test data");
1664
1665        let coef = model.coefficients();
1666        assert!((coef[0] - 2.0).abs() < 1e-4);
1667        assert!((model.intercept() - 0.0).abs() < 1e-4);
1668    }
1669
1670    #[test]
1671    fn test_predict_new_data() {
1672        // y = x + 1
1673        let x_train =
1674            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
1675        let y_train = Vector::from_slice(&[2.0, 3.0, 4.0]);
1676
1677        let mut model = LinearRegression::new();
1678        model
1679            .fit(&x_train, &y_train)
1680            .expect("Fit should succeed with valid test data");
1681
1682        let x_test =
1683            Matrix::from_vec(2, 1, vec![4.0, 5.0]).expect("Valid matrix dimensions for test");
1684        let predictions = model.predict(&x_test);
1685
1686        assert!((predictions[0] - 5.0).abs() < 1e-4);
1687        assert!((predictions[1] - 6.0).abs() < 1e-4);
1688    }
1689
1690    #[test]
1691    fn test_dimension_mismatch_error() {
1692        let x = Matrix::from_vec(3, 2, vec![1.0; 6]).expect("Valid matrix dimensions for test");
1693        let y = Vector::from_slice(&[1.0, 2.0]); // Wrong length
1694
1695        let mut model = LinearRegression::new();
1696        let result = model.fit(&x, &y);
1697        assert!(result.is_err());
1698    }
1699
1700    #[test]
1701    fn test_empty_data_error() {
1702        let x = Matrix::from_vec(0, 2, vec![]).expect("Valid matrix dimensions for test");
1703        let y = Vector::from_vec(vec![]);
1704
1705        let mut model = LinearRegression::new();
1706        let result = model.fit(&x, &y);
1707        assert!(result.is_err());
1708    }
1709
1710    #[test]
1711    fn test_with_noise() {
1712        // y ≈ 2x + 1 with some noise
1713        let x = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0])
1714            .expect("Valid matrix dimensions for test");
1715        let y = Vector::from_slice(&[3.1, 4.9, 7.2, 8.8, 11.1]);
1716
1717        let mut model = LinearRegression::new();
1718        model
1719            .fit(&x, &y)
1720            .expect("Fit should succeed with valid test data");
1721
1722        // Should still get approximately correct coefficients
1723        let coef = model.coefficients();
1724        assert!((coef[0] - 2.0).abs() < 0.2);
1725        assert!((model.intercept() - 1.0).abs() < 0.5);
1726
1727        // R² should be high but not perfect
1728        let r2 = model.score(&x, &y);
1729        assert!(r2 > 0.95);
1730        assert!(r2 < 1.0);
1731    }
1732
1733    #[test]
1734    fn test_default() {
1735        let model = LinearRegression::default();
1736        assert!(!model.is_fitted());
1737    }
1738
1739    #[test]
1740    fn test_clone() {
1741        let x =
1742            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
1743        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
1744
1745        let mut model = LinearRegression::new();
1746        model
1747            .fit(&x, &y)
1748            .expect("Fit should succeed with valid test data");
1749
1750        let cloned = model.clone();
1751        assert!(cloned.is_fitted());
1752        assert!((cloned.intercept() - model.intercept()).abs() < 1e-6);
1753    }
1754
1755    #[test]
1756    fn test_score_range() {
1757        // R² should be between negative infinity and 1
1758        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
1759            .expect("Valid matrix dimensions for test");
1760        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
1761
1762        let mut model = LinearRegression::new();
1763        model
1764            .fit(&x, &y)
1765            .expect("Fit should succeed with valid test data");
1766
1767        let r2 = model.score(&x, &y);
1768        assert!(r2 <= 1.0);
1769    }
1770
1771    #[test]
1772    fn test_prediction_invariant() {
1773        // Property: predict(fit(X, y), X) should approximate y
1774        // Use non-collinear data
1775        let x = Matrix::from_vec(5, 2, vec![1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 4.0, 5.0, 5.0, 4.0])
1776            .expect("Valid matrix dimensions for test");
1777        // y = 2*x1 + 3*x2 + 1
1778        let y = Vector::from_slice(&[6.0, 14.0, 13.0, 24.0, 23.0]);
1779
1780        let mut model = LinearRegression::new();
1781        model
1782            .fit(&x, &y)
1783            .expect("Fit should succeed with valid test data");
1784
1785        let predictions = model.predict(&x);
1786
1787        for i in 0..y.len() {
1788            assert!((predictions[i] - y[i]).abs() < 1e-3);
1789        }
1790    }
1791
1792    #[test]
1793    fn test_coefficients_length_invariant() {
1794        // Property: coefficients.len() == n_features
1795        // Use well-conditioned data with independent columns
1796        let x = Matrix::from_vec(
1797            6,
1798            3,
1799            vec![
1800                1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0,
1801                0.0, 1.0,
1802            ],
1803        )
1804        .expect("Valid matrix dimensions for test");
1805        let y = Vector::from_slice(&[1.0, 2.0, 3.0, 3.0, 5.0, 4.0]);
1806
1807        let mut model = LinearRegression::new();
1808        model
1809            .fit(&x, &y)
1810            .expect("Fit should succeed with valid test data");
1811
1812        assert_eq!(model.coefficients().len(), 3);
1813    }
1814
1815    #[test]
1816    fn test_larger_dataset() {
1817        // Test with more samples
1818        let n = 100;
1819        let mut x_data = Vec::with_capacity(n);
1820        let mut y_data = Vec::with_capacity(n);
1821
1822        for i in 0..n {
1823            let x_val = i as f32;
1824            x_data.push(x_val);
1825            y_data.push(2.0 * x_val + 3.0); // y = 2x + 3
1826        }
1827
1828        let x = Matrix::from_vec(n, 1, x_data).expect("Valid matrix dimensions for test");
1829        let y = Vector::from_vec(y_data);
1830
1831        let mut model = LinearRegression::new();
1832        model
1833            .fit(&x, &y)
1834            .expect("Fit should succeed with valid test data");
1835
1836        let coef = model.coefficients();
1837        assert!((coef[0] - 2.0).abs() < 1e-3);
1838        assert!((model.intercept() - 3.0).abs() < 1e-3);
1839    }
1840
1841    #[test]
1842    fn test_single_sample_single_feature() {
1843        // Edge case: minimum viable data
1844        let x = Matrix::from_vec(2, 1, vec![1.0, 2.0]).expect("Valid matrix dimensions for test");
1845        let y = Vector::from_slice(&[3.0, 5.0]);
1846
1847        let mut model = LinearRegression::new();
1848        model
1849            .fit(&x, &y)
1850            .expect("Fit should succeed with valid test data");
1851
1852        // y = 2x + 1
1853        let coef = model.coefficients();
1854        assert!((coef[0] - 2.0).abs() < 1e-4);
1855        assert!((model.intercept() - 1.0).abs() < 1e-4);
1856    }
1857
1858    #[test]
1859    fn test_negative_values() {
1860        // Test with negative coefficients and values
1861        let x = Matrix::from_vec(4, 1, vec![-2.0, -1.0, 0.0, 1.0])
1862            .expect("Valid matrix dimensions for test");
1863        let y = Vector::from_slice(&[5.0, 3.0, 1.0, -1.0]); // y = -2x + 1
1864
1865        let mut model = LinearRegression::new();
1866        model
1867            .fit(&x, &y)
1868            .expect("Fit should succeed with valid test data");
1869
1870        let coef = model.coefficients();
1871        assert!((coef[0] - (-2.0)).abs() < 1e-4);
1872        assert!((model.intercept() - 1.0).abs() < 1e-4);
1873    }
1874
1875    #[test]
1876    fn test_large_values() {
1877        // Test numerical stability with large values
1878        let x = Matrix::from_vec(3, 1, vec![1000.0, 2000.0, 3000.0])
1879            .expect("Valid matrix dimensions for test");
1880        let y = Vector::from_slice(&[2001.0, 4001.0, 6001.0]); // y = 2x + 1
1881
1882        let mut model = LinearRegression::new();
1883        model
1884            .fit(&x, &y)
1885            .expect("Fit should succeed with valid test data");
1886
1887        let coef = model.coefficients();
1888        assert!((coef[0] - 2.0).abs() < 1e-2);
1889        assert!((model.intercept() - 1.0).abs() < 10.0); // Larger tolerance for large values
1890    }
1891
1892    #[test]
1893    fn test_small_values() {
1894        // Test with small values
1895        let x = Matrix::from_vec(3, 1, vec![0.001, 0.002, 0.003])
1896            .expect("Valid matrix dimensions for test");
1897        let y = Vector::from_slice(&[0.003, 0.005, 0.007]); // y = 2x + 0.001
1898
1899        let mut model = LinearRegression::new();
1900        model
1901            .fit(&x, &y)
1902            .expect("Fit should succeed with valid test data");
1903
1904        let coef = model.coefficients();
1905        assert!((coef[0] - 2.0).abs() < 1e-2);
1906    }
1907
1908    #[test]
1909    fn test_zero_intercept_data() {
1910        // Data that should produce zero intercept
1911        let x =
1912            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
1913        let y = Vector::from_slice(&[2.0, 4.0, 6.0]); // y = 2x
1914
1915        let mut model = LinearRegression::new();
1916        model
1917            .fit(&x, &y)
1918            .expect("Fit should succeed with valid test data");
1919
1920        let coef = model.coefficients();
1921        assert!((coef[0] - 2.0).abs() < 1e-4);
1922        assert!(model.intercept().abs() < 1e-4);
1923    }
1924
1925    #[test]
1926    fn test_constant_target() {
1927        // All y values are the same
1928        let x =
1929            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
1930        let y = Vector::from_slice(&[5.0, 5.0, 5.0]);
1931
1932        let mut model = LinearRegression::new();
1933        model
1934            .fit(&x, &y)
1935            .expect("Fit should succeed with valid test data");
1936
1937        // Coefficient should be ~0, intercept should be ~5
1938        let coef = model.coefficients();
1939        assert!(coef[0].abs() < 1e-4);
1940        assert!((model.intercept() - 5.0).abs() < 1e-4);
1941    }
1942
1943    #[test]
1944    fn test_r2_score_bounds() {
1945        // R² should be in reasonable range for good fit
1946        let x = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0])
1947            .expect("Valid matrix dimensions for test");
1948        let y = Vector::from_slice(&[2.1, 3.9, 6.1, 7.9, 10.1]);
1949
1950        let mut model = LinearRegression::new();
1951        model
1952            .fit(&x, &y)
1953            .expect("Fit should succeed with valid test data");
1954
1955        let r2 = model.score(&x, &y);
1956        assert!(r2 > 0.0);
1957        assert!(r2 <= 1.0);
1958    }
1959
1960    #[test]
1961    fn test_extrapolation() {
1962        // Test prediction outside training range
1963        let x_train =
1964            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
1965        let y_train = Vector::from_slice(&[2.0, 4.0, 6.0]); // y = 2x
1966
1967        let mut model = LinearRegression::new();
1968        model
1969            .fit(&x_train, &y_train)
1970            .expect("Fit should succeed with valid test data");
1971
1972        // Predict at x = 10 (extrapolation)
1973        let x_test = Matrix::from_vec(1, 1, vec![10.0]).expect("Valid matrix dimensions for test");
1974        let predictions = model.predict(&x_test);
1975
1976        assert!((predictions[0] - 20.0).abs() < 1e-4);
1977    }
1978
1979    #[test]
1980    fn test_underdetermined_system_with_intercept() {
1981        // n_samples < n_features + 1 (underdetermined with intercept)
1982        // 3 samples, 5 features, fit_intercept=true means we need 6 parameters
1983        let x = Matrix::from_vec(
1984            3,
1985            5,
1986            vec![
1987                1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0, 7.0,
1988            ],
1989        )
1990        .expect("Valid matrix dimensions for test");
1991        let y = Vector::from_vec(vec![10.0, 20.0, 30.0]);
1992
1993        let mut model = LinearRegression::new();
1994        let result = model.fit(&x, &y);
1995
1996        assert!(result.is_err());
1997        let error_msg = result.expect_err("Should fail when underdetermined system with intercept");
1998        let error_str = error_msg.to_string();
1999        // Should mention samples, features, and suggest solutions
2000        assert!(
2001            error_str.contains("samples") || error_str.contains("features"),
2002            "Error message should mention samples or features: {error_str}"
2003        );
2004    }
2005
2006    #[test]
2007    fn test_underdetermined_system_without_intercept() {
2008        // n_samples < n_features (underdetermined without intercept)
2009        // 3 samples, 5 features, fit_intercept=false
2010        let x = Matrix::from_vec(
2011            3,
2012            5,
2013            vec![
2014                1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0, 7.0,
2015            ],
2016        )
2017        .expect("Valid matrix dimensions for test");
2018        let y = Vector::from_vec(vec![10.0, 20.0, 30.0]);
2019
2020        let mut model = LinearRegression::new().with_intercept(false);
2021        let result = model.fit(&x, &y);
2022
2023        assert!(result.is_err());
2024        let error_msg =
2025            result.expect_err("Should fail when underdetermined system without intercept");
2026        let error_str = error_msg.to_string();
2027        assert!(
2028            error_str.contains("samples") || error_str.contains("features"),
2029            "Error message should be helpful: {error_str}"
2030        );
2031    }
2032
2033    #[test]
2034    fn test_exactly_determined_system() {
2035        // n_samples == n_features + 1 (exactly determined with intercept)
2036        // 4 samples, 3 features, fit_intercept=true means 4 parameters
2037        let x = Matrix::from_vec(
2038            4,
2039            3,
2040            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
2041        )
2042        .expect("Valid matrix dimensions for test");
2043        let y = Vector::from_vec(vec![1.0, 2.0, 3.0, 6.0]);
2044
2045        let mut model = LinearRegression::new();
2046        let result = model.fit(&x, &y);
2047
2048        // This should succeed (exactly determined)
2049        assert!(result.is_ok(), "Exactly determined system should work");
2050    }
2051
2052    #[test]
2053    fn test_save_load_binary() {
2054        use std::fs;
2055        use std::path::Path;
2056
2057        // Train a model
2058        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2059            .expect("Valid matrix dimensions for test");
2060        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]); // y = 2x + 1
2061
2062        let mut model = LinearRegression::new();
2063        model
2064            .fit(&x, &y)
2065            .expect("Fit should succeed with valid test data");
2066
2067        // Save to file
2068        let path = Path::new("/tmp/test_linear_regression.bin");
2069        model.save(path).expect("Failed to save model");
2070
2071        // Load from file
2072        let loaded_model = LinearRegression::load(path).expect("Failed to load model");
2073
2074        // Verify loaded model matches original
2075        let original_pred = model.predict(&x);
2076        let loaded_pred = loaded_model.predict(&x);
2077
2078        for i in 0..original_pred.len() {
2079            assert!(
2080                (original_pred[i] - loaded_pred[i]).abs() < 1e-6,
2081                "Loaded model predictions don't match original"
2082            );
2083        }
2084
2085        // Verify coefficients and intercept match
2086        assert_eq!(
2087            model.coefficients().len(),
2088            loaded_model.coefficients().len()
2089        );
2090        for i in 0..model.coefficients().len() {
2091            assert!((model.coefficients()[i] - loaded_model.coefficients()[i]).abs() < 1e-6);
2092        }
2093        assert!((model.intercept() - loaded_model.intercept()).abs() < 1e-6);
2094
2095        // Cleanup
2096        fs::remove_file(path).ok();
2097    }
2098
2099    #[test]
2100    fn test_with_intercept_returns_self() {
2101        // Test that with_intercept returns the modified self, not a default
2102        // This catches the mutation: with_intercept -> Default::default()
2103        let model = LinearRegression::new().with_intercept(false);
2104
2105        // If mutation returns Default::default(), fit_intercept would be true
2106        // Since new() sets fit_intercept = true by default
2107
2108        // We need to verify the model actually has fit_intercept = false
2109        // by checking the fitted behavior
2110        let x =
2111            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2112        let y = Vector::from_slice(&[2.0, 4.0, 6.0]); // y = 2x
2113
2114        let mut model = model;
2115        model
2116            .fit(&x, &y)
2117            .expect("Fit should succeed with valid test data");
2118
2119        // Without intercept, the model should pass through origin
2120        // Predicting x=0 should give y=0 (no intercept term)
2121        let x_zero = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
2122        let pred = model.predict(&x_zero);
2123
2124        assert!(
2125            pred[0].abs() < 1e-6,
2126            "Model without intercept should predict 0 at x=0, got {}",
2127            pred[0]
2128        );
2129    }
2130
2131    #[test]
2132    fn test_with_intercept_builder_chain() {
2133        // Test that builder pattern works correctly
2134        // with_intercept(false) followed by fitting should not have intercept
2135        let x = Matrix::from_vec(2, 1, vec![1.0, 2.0]).expect("Valid matrix dimensions for test");
2136        let y = Vector::from_slice(&[3.0, 5.0]); // y = 2x + 1
2137
2138        // Model with intercept
2139        let mut with_int = LinearRegression::new().with_intercept(true);
2140        with_int
2141            .fit(&x, &y)
2142            .expect("Fit should succeed with valid test data");
2143
2144        // Model without intercept
2145        let mut without_int = LinearRegression::new().with_intercept(false);
2146        without_int
2147            .fit(&x, &y)
2148            .expect("Fit should succeed with valid test data");
2149
2150        // The intercept should be different
2151        // With intercept: should have non-zero intercept for this data
2152        // Without intercept: intercept is always 0
2153        assert!(
2154            with_int.intercept().abs() > 0.1,
2155            "Model with intercept should have non-zero intercept"
2156        );
2157        assert!(
2158            without_int.intercept().abs() < 1e-6,
2159            "Model without intercept should have zero intercept, got {}",
2160            without_int.intercept()
2161        );
2162    }
2163
2164    // Ridge regression tests
2165    #[test]
2166    fn test_ridge_new() {
2167        let model = Ridge::new(1.0);
2168        assert!(!model.is_fitted());
2169        assert!((model.alpha() - 1.0).abs() < 1e-6);
2170    }
2171
2172    #[test]
2173    fn test_ridge_simple_regression() {
2174        // y = 2x + 1
2175        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2176            .expect("Valid matrix dimensions for test");
2177        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2178
2179        let mut model = Ridge::new(0.0); // No regularization = OLS
2180        model
2181            .fit(&x, &y)
2182            .expect("Fit should succeed with valid test data");
2183
2184        assert!(model.is_fitted());
2185
2186        // Check predictions are close (might not be perfect due to regularization)
2187        let r2 = model.score(&x, &y);
2188        assert!(r2 > 0.99);
2189    }
2190
2191    #[test]
2192    fn test_ridge_regularization_shrinks_coefficients() {
2193        // Test that higher alpha shrinks coefficients
2194        let x = Matrix::from_vec(5, 2, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0])
2195            .expect("Valid matrix dimensions for test");
2196        let y = Vector::from_slice(&[4.0, 8.0, 12.0, 16.0, 20.0]);
2197
2198        // Low regularization
2199        let mut low_reg = Ridge::new(0.01);
2200        low_reg
2201            .fit(&x, &y)
2202            .expect("Fit should succeed with valid test data");
2203
2204        // High regularization
2205        let mut high_reg = Ridge::new(100.0);
2206        high_reg
2207            .fit(&x, &y)
2208            .expect("Fit should succeed with valid test data");
2209
2210        // Higher regularization should produce smaller coefficient magnitudes
2211        let low_coef = low_reg.coefficients();
2212        let high_coef = high_reg.coefficients();
2213        let low_norm: f32 = (0..low_coef.len()).map(|i| low_coef[i] * low_coef[i]).sum();
2214        let high_norm: f32 = (0..high_coef.len())
2215            .map(|i| high_coef[i] * high_coef[i])
2216            .sum();
2217
2218        assert!(
2219            high_norm < low_norm,
2220            "High regularization should shrink coefficients: {high_norm} < {low_norm}"
2221        );
2222    }
2223
2224    #[test]
2225    fn test_ridge_multivariate() {
2226        // y = 1 + 2*x1 + 3*x2
2227        let x = Matrix::from_vec(5, 2, vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0])
2228            .expect("Valid matrix dimensions for test");
2229        let y = Vector::from_slice(&[6.0, 8.0, 9.0, 11.0, 16.0]);
2230
2231        let mut model = Ridge::new(0.1);
2232        model
2233            .fit(&x, &y)
2234            .expect("Fit should succeed with valid test data");
2235
2236        let r2 = model.score(&x, &y);
2237        assert!(r2 > 0.95);
2238    }
2239
2240    #[test]
2241    fn test_ridge_no_intercept() {
2242        // y = 2x
2243        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2244            .expect("Valid matrix dimensions for test");
2245        let y = Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]);
2246
2247        let mut model = Ridge::new(0.1).with_intercept(false);
2248        model
2249            .fit(&x, &y)
2250            .expect("Fit should succeed with valid test data");
2251
2252        assert!((model.intercept() - 0.0).abs() < 1e-6);
2253    }
2254
2255    #[test]
2256    fn test_ridge_dimension_mismatch_error() {
2257        let x = Matrix::from_vec(3, 2, vec![1.0; 6]).expect("Valid matrix dimensions for test");
2258        let y = Vector::from_slice(&[1.0, 2.0]); // Wrong length
2259
2260        let mut model = Ridge::new(1.0);
2261        let result = model.fit(&x, &y);
2262        assert!(result.is_err());
2263    }
2264
2265    #[test]
2266    fn test_ridge_empty_data_error() {
2267        let x = Matrix::from_vec(0, 2, vec![]).expect("Valid matrix dimensions for test");
2268        let y = Vector::from_vec(vec![]);
2269
2270        let mut model = Ridge::new(1.0);
2271        let result = model.fit(&x, &y);
2272        assert!(result.is_err());
2273    }
2274
2275    #[test]
2276    fn test_ridge_underdetermined_system() {
2277        // Ridge can handle underdetermined systems due to regularization
2278        // 3 samples, 5 features
2279        let x = Matrix::from_vec(
2280            3,
2281            5,
2282            vec![
2283                1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0, 7.0,
2284            ],
2285        )
2286        .expect("Valid matrix dimensions for test");
2287        let y = Vector::from_vec(vec![10.0, 20.0, 30.0]);
2288
2289        // With sufficient regularization, this should work
2290        let mut model = Ridge::new(10.0);
2291        let result = model.fit(&x, &y);
2292        assert!(
2293            result.is_ok(),
2294            "Ridge should handle underdetermined systems"
2295        );
2296    }
2297
2298    #[test]
2299    fn test_ridge_clone() {
2300        let x =
2301            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2302        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
2303
2304        let mut model = Ridge::new(0.5);
2305        model
2306            .fit(&x, &y)
2307            .expect("Fit should succeed with valid test data");
2308
2309        let cloned = model.clone();
2310        assert!(cloned.is_fitted());
2311        assert!((cloned.alpha() - model.alpha()).abs() < 1e-6);
2312        assert!((cloned.intercept() - model.intercept()).abs() < 1e-6);
2313    }
2314
2315    #[test]
2316    fn test_ridge_alpha_zero_equals_ols() {
2317        // Ridge with alpha=0 should give same results as OLS
2318        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2319            .expect("Valid matrix dimensions for test");
2320        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2321
2322        let mut ridge = Ridge::new(0.0);
2323        ridge
2324            .fit(&x, &y)
2325            .expect("Fit should succeed with valid test data");
2326
2327        let mut ols = LinearRegression::new();
2328        ols.fit(&x, &y)
2329            .expect("Fit should succeed with valid test data");
2330
2331        // Coefficients should be nearly identical
2332        assert!(
2333            (ridge.coefficients()[0] - ols.coefficients()[0]).abs() < 1e-4,
2334            "Ridge with alpha=0 should equal OLS"
2335        );
2336        assert!((ridge.intercept() - ols.intercept()).abs() < 1e-4);
2337    }
2338
2339    #[test]
2340    fn test_ridge_save_load() {
2341        use std::fs;
2342        use std::path::Path;
2343
2344        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2345            .expect("Valid matrix dimensions for test");
2346        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2347
2348        let mut model = Ridge::new(0.5);
2349        model
2350            .fit(&x, &y)
2351            .expect("Fit should succeed with valid test data");
2352
2353        let path = Path::new("/tmp/test_ridge.bin");
2354        model.save(path).expect("Failed to save model");
2355
2356        let loaded = Ridge::load(path).expect("Failed to load model");
2357
2358        // Verify loaded model matches original
2359        assert!((loaded.alpha() - model.alpha()).abs() < 1e-6);
2360        let original_pred = model.predict(&x);
2361        let loaded_pred = loaded.predict(&x);
2362
2363        for i in 0..original_pred.len() {
2364            assert!((original_pred[i] - loaded_pred[i]).abs() < 1e-6);
2365        }
2366
2367        fs::remove_file(path).ok();
2368    }
2369
2370    #[test]
2371    fn test_ridge_with_intercept_builder() {
2372        let model = Ridge::new(1.0).with_intercept(false);
2373
2374        let x =
2375            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2376        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
2377
2378        let mut model = model;
2379        model
2380            .fit(&x, &y)
2381            .expect("Fit should succeed with valid test data");
2382
2383        // Without intercept, predicting at x=0 should give 0
2384        let x_zero = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
2385        let pred = model.predict(&x_zero);
2386
2387        assert!(
2388            pred[0].abs() < 1e-6,
2389            "Ridge without intercept should predict 0 at x=0"
2390        );
2391    }
2392
2393    #[test]
2394    fn test_ridge_coefficients_length() {
2395        let x = Matrix::from_vec(5, 3, vec![1.0; 15]).expect("Valid matrix dimensions for test");
2396        let y = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
2397
2398        let mut model = Ridge::new(1.0);
2399        model
2400            .fit(&x, &y)
2401            .expect("Fit should succeed with valid test data");
2402
2403        assert_eq!(model.coefficients().len(), 3);
2404    }
2405
2406    // Lasso regression tests
2407    #[test]
2408    fn test_lasso_new() {
2409        let model = Lasso::new(1.0);
2410        assert!(!model.is_fitted());
2411        assert!((model.alpha() - 1.0).abs() < 1e-6);
2412    }
2413
2414    #[test]
2415    fn test_lasso_simple_regression() {
2416        // y = 2x + 1
2417        let x = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0])
2418            .expect("Valid matrix dimensions for test");
2419        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0]);
2420
2421        let mut model = Lasso::new(0.01); // Small regularization
2422        model
2423            .fit(&x, &y)
2424            .expect("Fit should succeed with valid test data");
2425
2426        assert!(model.is_fitted());
2427
2428        let r2 = model.score(&x, &y);
2429        assert!(r2 > 0.98, "R² should be > 0.98, got {r2}");
2430    }
2431
2432    #[test]
2433    fn test_lasso_produces_sparsity() {
2434        // Test that Lasso with high alpha produces sparse coefficients
2435        // Create data where only first feature matters: y = x1
2436        let x = Matrix::from_vec(
2437            6,
2438            3,
2439            vec![
2440                1.0, 0.1, 0.2, 2.0, 0.2, 0.1, 3.0, 0.1, 0.3, 4.0, 0.3, 0.1, 5.0, 0.2, 0.2, 6.0,
2441                0.1, 0.1,
2442            ],
2443        )
2444        .expect("Valid matrix dimensions for test");
2445        let y = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2446
2447        let mut model = Lasso::new(1.0); // High regularization
2448        model
2449            .fit(&x, &y)
2450            .expect("Fit should succeed with valid test data");
2451
2452        // Count non-zero coefficients
2453        let coef = model.coefficients();
2454        let mut non_zero = 0;
2455        for i in 0..coef.len() {
2456            if coef[i].abs() > 1e-4 {
2457                non_zero += 1;
2458            }
2459        }
2460
2461        // With high alpha, some coefficients should be zeroed out
2462        assert!(
2463            non_zero < coef.len(),
2464            "Lasso should produce sparse solution, got {} non-zero out of {}",
2465            non_zero,
2466            coef.len()
2467        );
2468    }
2469
2470    #[test]
2471    fn test_lasso_multivariate() {
2472        // y = 1 + 2*x1 + 3*x2
2473        let x = Matrix::from_vec(
2474            6,
2475            2,
2476            vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0],
2477        )
2478        .expect("Valid matrix dimensions for test");
2479        let y = Vector::from_slice(&[6.0, 8.0, 9.0, 11.0, 16.0, 21.0]);
2480
2481        let mut model = Lasso::new(0.01);
2482        model
2483            .fit(&x, &y)
2484            .expect("Fit should succeed with valid test data");
2485
2486        let r2 = model.score(&x, &y);
2487        assert!(r2 > 0.95, "R² should be > 0.95, got {r2}");
2488    }
2489
2490    #[test]
2491    fn test_lasso_no_intercept() {
2492        // y = 2x
2493        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2494            .expect("Valid matrix dimensions for test");
2495        let y = Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]);
2496
2497        let mut model = Lasso::new(0.01).with_intercept(false);
2498        model
2499            .fit(&x, &y)
2500            .expect("Fit should succeed with valid test data");
2501
2502        assert!((model.intercept() - 0.0).abs() < 1e-6);
2503    }
2504
2505    #[test]
2506    fn test_lasso_dimension_mismatch_error() {
2507        let x = Matrix::from_vec(3, 2, vec![1.0; 6]).expect("Valid matrix dimensions for test");
2508        let y = Vector::from_slice(&[1.0, 2.0]); // Wrong length
2509
2510        let mut model = Lasso::new(1.0);
2511        let result = model.fit(&x, &y);
2512        assert!(result.is_err());
2513    }
2514
2515    #[test]
2516    fn test_lasso_empty_data_error() {
2517        let x = Matrix::from_vec(0, 2, vec![]).expect("Valid matrix dimensions for test");
2518        let y = Vector::from_vec(vec![]);
2519
2520        let mut model = Lasso::new(1.0);
2521        let result = model.fit(&x, &y);
2522        assert!(result.is_err());
2523    }
2524
2525    #[test]
2526    fn test_lasso_clone() {
2527        let x =
2528            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2529        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
2530
2531        let mut model = Lasso::new(0.5);
2532        model
2533            .fit(&x, &y)
2534            .expect("Fit should succeed with valid test data");
2535
2536        let cloned = model.clone();
2537        assert!(cloned.is_fitted());
2538        assert!((cloned.alpha() - model.alpha()).abs() < 1e-6);
2539        assert!((cloned.intercept() - model.intercept()).abs() < 1e-6);
2540    }
2541
2542    #[test]
2543    fn test_lasso_save_load() {
2544        use std::fs;
2545        use std::path::Path;
2546
2547        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2548            .expect("Valid matrix dimensions for test");
2549        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2550
2551        let mut model = Lasso::new(0.1);
2552        model
2553            .fit(&x, &y)
2554            .expect("Fit should succeed with valid test data");
2555
2556        let path = Path::new("/tmp/test_lasso.bin");
2557        model.save(path).expect("Failed to save model");
2558
2559        let loaded = Lasso::load(path).expect("Failed to load model");
2560
2561        assert!((loaded.alpha() - model.alpha()).abs() < 1e-6);
2562        let original_pred = model.predict(&x);
2563        let loaded_pred = loaded.predict(&x);
2564
2565        for i in 0..original_pred.len() {
2566            assert!((original_pred[i] - loaded_pred[i]).abs() < 1e-6);
2567        }
2568
2569        fs::remove_file(path).ok();
2570    }
2571
2572    #[test]
2573    fn test_lasso_with_intercept_builder() {
2574        let model = Lasso::new(1.0).with_intercept(false);
2575
2576        let x =
2577            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2578        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
2579
2580        let mut model = model;
2581        model
2582            .fit(&x, &y)
2583            .expect("Fit should succeed with valid test data");
2584
2585        let x_zero = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
2586        let pred = model.predict(&x_zero);
2587
2588        assert!(
2589            pred[0].abs() < 1e-6,
2590            "Lasso without intercept should predict 0 at x=0"
2591        );
2592    }
2593
2594    #[test]
2595    fn test_lasso_coefficients_length() {
2596        let x = Matrix::from_vec(5, 3, vec![1.0; 15]).expect("Valid matrix dimensions for test");
2597        let y = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
2598
2599        let mut model = Lasso::new(0.1);
2600        model
2601            .fit(&x, &y)
2602            .expect("Fit should succeed with valid test data");
2603
2604        assert_eq!(model.coefficients().len(), 3);
2605    }
2606
2607    #[test]
2608    fn test_lasso_with_max_iter() {
2609        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2610            .expect("Valid matrix dimensions for test");
2611        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2612
2613        let mut model = Lasso::new(0.1).with_max_iter(100);
2614        model
2615            .fit(&x, &y)
2616            .expect("Fit should succeed with valid test data");
2617
2618        assert!(model.is_fitted());
2619    }
2620
2621    #[test]
2622    fn test_lasso_with_tol() {
2623        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2624            .expect("Valid matrix dimensions for test");
2625        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2626
2627        let mut model = Lasso::new(0.1).with_tol(1e-6);
2628        model
2629            .fit(&x, &y)
2630            .expect("Fit should succeed with valid test data");
2631
2632        assert!(model.is_fitted());
2633    }
2634
2635    #[test]
2636    fn test_lasso_soft_threshold() {
2637        // Test the soft-thresholding function
2638        assert!((Lasso::soft_threshold(5.0, 2.0) - 3.0).abs() < 1e-6);
2639        assert!((Lasso::soft_threshold(-5.0, 2.0) - (-3.0)).abs() < 1e-6);
2640        assert!((Lasso::soft_threshold(1.0, 2.0) - 0.0).abs() < 1e-6);
2641        assert!((Lasso::soft_threshold(-1.0, 2.0) - 0.0).abs() < 1e-6);
2642    }
2643
2644    // ==================== ElasticNet Tests ====================
2645
2646    #[test]
2647    fn test_elastic_net_new() {
2648        let model = ElasticNet::new(1.0, 0.5);
2649        assert!(!model.is_fitted());
2650        assert!((model.alpha() - 1.0).abs() < 1e-6);
2651        assert!((model.l1_ratio() - 0.5).abs() < 1e-6);
2652    }
2653
2654    #[test]
2655    fn test_elastic_net_simple() {
2656        // y = 2x + 1
2657        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2658            .expect("Valid matrix dimensions for test");
2659        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2660
2661        let mut model = ElasticNet::new(0.01, 0.5);
2662        model
2663            .fit(&x, &y)
2664            .expect("Fit should succeed with valid test data");
2665
2666        assert!(model.is_fitted());
2667
2668        // Should recover approximately y = 2x + 1
2669        let coef = model.coefficients();
2670        assert!((coef[0] - 2.0).abs() < 0.5); // Some regularization effect
2671        assert!((model.intercept() - 1.0).abs() < 1.0);
2672    }
2673
2674    #[test]
2675    fn test_elastic_net_multivariate() {
2676        // y = 2*x1 + 3*x2
2677        let x = Matrix::from_vec(4, 2, vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0])
2678            .expect("Valid matrix dimensions for test");
2679        let y = Vector::from_slice(&[5.0, 7.0, 8.0, 10.0]);
2680
2681        let mut model = ElasticNet::new(0.01, 0.5);
2682        model
2683            .fit(&x, &y)
2684            .expect("Fit should succeed with valid test data");
2685
2686        let predictions = model.predict(&x);
2687        for i in 0..4 {
2688            assert!((predictions[i] - y[i]).abs() < 1.0);
2689        }
2690    }
2691
2692    #[test]
2693    fn test_elastic_net_l1_ratio_pure_l1() {
2694        // l1_ratio=1.0 should behave like Lasso
2695        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2696            .expect("Valid matrix dimensions for test");
2697        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2698
2699        let mut elastic = ElasticNet::new(0.1, 1.0);
2700        elastic
2701            .fit(&x, &y)
2702            .expect("Fit should succeed with valid test data");
2703
2704        let mut lasso = Lasso::new(0.1);
2705        lasso
2706            .fit(&x, &y)
2707            .expect("Fit should succeed with valid test data");
2708
2709        // Should have similar coefficients
2710        let elastic_coef = elastic.coefficients();
2711        let lasso_coef = lasso.coefficients();
2712        assert!((elastic_coef[0] - lasso_coef[0]).abs() < 0.1);
2713    }
2714
2715    #[test]
2716    fn test_elastic_net_l1_ratio_pure_l2() {
2717        // l1_ratio=0.0 should behave like Ridge
2718        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2719            .expect("Valid matrix dimensions for test");
2720        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2721
2722        let mut elastic = ElasticNet::new(0.1, 0.0);
2723        elastic
2724            .fit(&x, &y)
2725            .expect("Fit should succeed with valid test data");
2726
2727        let mut ridge = Ridge::new(0.1);
2728        ridge
2729            .fit(&x, &y)
2730            .expect("Fit should succeed with valid test data");
2731
2732        // Should have similar coefficients
2733        let elastic_coef = elastic.coefficients();
2734        let ridge_coef = ridge.coefficients();
2735        assert!((elastic_coef[0] - ridge_coef[0]).abs() < 0.5);
2736    }
2737
2738    #[test]
2739    fn test_elastic_net_dimension_mismatch() {
2740        let x = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
2741            .expect("Valid matrix dimensions for test");
2742        let y = Vector::from_slice(&[1.0, 2.0]); // Wrong length
2743
2744        let mut model = ElasticNet::new(0.1, 0.5);
2745        let result = model.fit(&x, &y);
2746        assert!(result.is_err());
2747    }
2748
2749    #[test]
2750    fn test_elastic_net_empty_data() {
2751        let x = Matrix::from_vec(0, 2, vec![]).expect("Valid matrix dimensions for test");
2752        let y = Vector::from_vec(vec![]);
2753
2754        let mut model = ElasticNet::new(0.1, 0.5);
2755        let result = model.fit(&x, &y);
2756        assert!(result.is_err());
2757    }
2758
2759    #[test]
2760    #[should_panic(expected = "Model not fitted")]
2761    fn test_elastic_net_predict_not_fitted() {
2762        let model = ElasticNet::new(0.1, 0.5);
2763        let x = Matrix::from_vec(1, 1, vec![1.0]).expect("Valid matrix dimensions for test");
2764        let _ = model.predict(&x);
2765    }
2766
2767    #[test]
2768    fn test_elastic_net_score() {
2769        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2770            .expect("Valid matrix dimensions for test");
2771        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2772
2773        let mut model = ElasticNet::new(0.01, 0.5);
2774        model
2775            .fit(&x, &y)
2776            .expect("Fit should succeed with valid test data");
2777
2778        let r2 = model.score(&x, &y);
2779        assert!(r2 > 0.9); // Should fit well with small alpha
2780    }
2781
2782    #[test]
2783    fn test_elastic_net_clone() {
2784        let x =
2785            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2786        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
2787
2788        let mut model = ElasticNet::new(0.5, 0.5);
2789        model
2790            .fit(&x, &y)
2791            .expect("Fit should succeed with valid test data");
2792
2793        let cloned = model.clone();
2794        assert!(cloned.is_fitted());
2795        assert!((cloned.alpha() - model.alpha()).abs() < 1e-6);
2796        assert!((cloned.l1_ratio() - model.l1_ratio()).abs() < 1e-6);
2797        assert!((cloned.intercept() - model.intercept()).abs() < 1e-6);
2798    }
2799
2800    #[test]
2801    fn test_elastic_net_save_load() {
2802        use std::fs;
2803        use std::path::Path;
2804
2805        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2806            .expect("Valid matrix dimensions for test");
2807        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2808
2809        let mut model = ElasticNet::new(0.1, 0.5);
2810        model
2811            .fit(&x, &y)
2812            .expect("Fit should succeed with valid test data");
2813
2814        let path = Path::new("/tmp/test_elastic_net.bin");
2815        model.save(path).expect("Failed to save model");
2816
2817        let loaded = ElasticNet::load(path).expect("Failed to load model");
2818
2819        assert!((loaded.alpha() - model.alpha()).abs() < 1e-6);
2820        assert!((loaded.l1_ratio() - model.l1_ratio()).abs() < 1e-6);
2821        let original_pred = model.predict(&x);
2822        let loaded_pred = loaded.predict(&x);
2823
2824        for i in 0..original_pred.len() {
2825            assert!((original_pred[i] - loaded_pred[i]).abs() < 1e-6);
2826        }
2827
2828        fs::remove_file(path).ok();
2829    }
2830
2831    #[test]
2832    fn test_elastic_net_with_intercept_builder() {
2833        let model = ElasticNet::new(1.0, 0.5).with_intercept(false);
2834
2835        let x =
2836            Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
2837        let y = Vector::from_slice(&[2.0, 4.0, 6.0]);
2838
2839        let mut model = model;
2840        model
2841            .fit(&x, &y)
2842            .expect("Fit should succeed with valid test data");
2843
2844        let x_zero = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
2845        let pred = model.predict(&x_zero);
2846        assert!((pred[0] - 0.0).abs() < 1e-6); // No intercept
2847    }
2848
2849    #[test]
2850    fn test_elastic_net_multivariate_coefficients() {
2851        let x = Matrix::from_vec(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
2852            .expect("Valid matrix dimensions for test");
2853        let y = Vector::from_slice(&[6.0, 15.0, 24.0]);
2854
2855        let mut model = ElasticNet::new(0.1, 0.5);
2856        model
2857            .fit(&x, &y)
2858            .expect("Fit should succeed with valid test data");
2859
2860        assert_eq!(model.coefficients().len(), 3);
2861    }
2862
2863    #[test]
2864    fn test_elastic_net_with_max_iter() {
2865        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2866            .expect("Valid matrix dimensions for test");
2867        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2868
2869        let mut model = ElasticNet::new(0.1, 0.5).with_max_iter(100);
2870        model
2871            .fit(&x, &y)
2872            .expect("Fit should succeed with valid test data");
2873
2874        assert!(model.is_fitted());
2875    }
2876
2877    #[test]
2878    fn test_elastic_net_with_tol() {
2879        let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
2880            .expect("Valid matrix dimensions for test");
2881        let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
2882
2883        let mut model = ElasticNet::new(0.1, 0.5).with_tol(1e-6);
2884        model
2885            .fit(&x, &y)
2886            .expect("Fit should succeed with valid test data");
2887
2888        assert!(model.is_fitted());
2889    }
2890}