// aprender/linear_model/mod.rs
1//! Linear models for regression.
2//!
3//! Includes Ordinary Least Squares (OLS) and regularized regression.
4
5use crate::error::Result;
6use crate::metrics::r_squared;
7use crate::primitives::{Matrix, Vector};
8use crate::traits::Estimator;
9use serde::{Deserialize, Serialize};
10use std::fs;
11use std::path::Path;
12
/// Ordinary Least Squares (OLS) linear regression.
///
/// Fits a linear model by minimizing the residual sum of squares between
/// observed targets and predicted targets. The model equation is:
///
/// ```text
/// y = X β + ε
/// ```
///
/// where `β` is the coefficient vector and `ε` is random error.
///
/// # Solver
///
/// Uses normal equations: `β = (X^T X)^-1 X^T y` via Cholesky decomposition.
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
///
/// // Simple linear regression: y = 2x + 1
/// let x = Matrix::from_vec(4, 1, vec![
///     1.0,
///     2.0,
///     3.0,
///     4.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
///
/// let mut model = LinearRegression::new();
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.99);
/// ```
///
/// # Performance
///
/// - Time complexity: O(n²p + p³) where n = samples, p = features
/// - Space complexity: O(np)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression {
    /// Coefficients for features (excluding intercept).
    /// `None` until `fit()` succeeds; accessors panic or error when unset.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term. Remains `0.0` when `fit_intercept` is false.
    intercept: f32,
    /// Whether to fit an intercept (adds a column of ones to the design matrix).
    fit_intercept: bool,
}
63
64impl Default for LinearRegression {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl LinearRegression {
71    /// Creates a new `LinearRegression` with default settings.
72    #[must_use]
73    pub fn new() -> Self {
74        Self {
75            coefficients: None,
76            intercept: 0.0,
77            fit_intercept: true,
78        }
79    }
80
81    /// Sets whether to fit an intercept term.
82    #[must_use]
83    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
84        self.fit_intercept = fit_intercept;
85        self
86    }
87
88    /// Returns the coefficients (excluding intercept).
89    ///
90    /// # Panics
91    ///
92    /// Panics if model is not fitted.
93    #[must_use]
94    pub fn coefficients(&self) -> &Vector<f32> {
95        self.coefficients
96            .as_ref()
97            .expect("Model not fitted. Call fit() first.")
98    }
99
100    /// Returns the intercept term.
101    #[must_use]
102    pub fn intercept(&self) -> f32 {
103        self.intercept
104    }
105
106    /// Returns true if the model has been fitted.
107    #[must_use]
108    pub fn is_fitted(&self) -> bool {
109        self.coefficients.is_some()
110    }
111
112    /// Saves the model to a binary file using bincode.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if serialization or file writing fails.
117    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
118        let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
119        fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
120        Ok(())
121    }
122
123    /// Loads a model from a binary file.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if file reading or deserialization fails.
128    pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
129        let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
130        let model =
131            bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
132        Ok(model)
133    }
134
135    /// Saves the model to `SafeTensors` format.
136    ///
137    /// `SafeTensors` format is compatible with:
138    /// - `HuggingFace` ecosystem
139    /// - Ollama (can convert to GGUF)
140    /// - `PyTorch`, TensorFlow
141    /// - realizar inference engine
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if:
146    /// - Model is not fitted
147    /// - Serialization fails
148    /// - File writing fails
149    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
150        use crate::serialization::safetensors;
151        use std::collections::BTreeMap;
152
153        // Verify model is fitted
154        let coefficients = self
155            .coefficients
156            .as_ref()
157            .ok_or("Cannot save unfitted model. Call fit() first.")?;
158
159        // Prepare tensors (BTreeMap ensures deterministic ordering)
160        let mut tensors = BTreeMap::new();
161
162        // Coefficients tensor
163        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
164        let coef_shape = vec![coefficients.len()];
165        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
166
167        // Intercept tensor
168        let intercept_data = vec![self.intercept];
169        let intercept_shape = vec![1];
170        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
171
172        // Save to SafeTensors format
173        safetensors::save_safetensors(path, &tensors)?;
174        Ok(())
175    }
176
177    /// Loads a model from `SafeTensors` format.
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if:
182    /// - File reading fails
183    /// - `SafeTensors` format is invalid
184    /// - Required tensors are missing
185    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
186        use crate::serialization::safetensors;
187
188        // Load SafeTensors file
189        let (metadata, raw_data) = safetensors::load_safetensors(path)?;
190
191        // Extract coefficients tensor
192        let coef_meta = metadata
193            .get("coefficients")
194            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
195        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
196
197        // Extract intercept tensor
198        let intercept_meta = metadata
199            .get("intercept")
200            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
201        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
202
203        // Validate intercept shape
204        if intercept_data.len() != 1 {
205            return Err(format!(
206                "Invalid intercept tensor: expected 1 value, got {}",
207                intercept_data.len()
208            ));
209        }
210
211        // Construct model
212        Ok(Self {
213            coefficients: Some(Vector::from_vec(coef_data)),
214            intercept: intercept_data[0],
215            fit_intercept: true, // Default to true for loaded models
216        })
217    }
218
219    /// Adds an intercept column of ones to the design matrix.
220    fn add_intercept_column(x: &Matrix<f32>) -> Matrix<f32> {
221        let (n_rows, n_cols) = x.shape();
222        let mut data = Vec::with_capacity(n_rows * (n_cols + 1));
223
224        for i in 0..n_rows {
225            data.push(1.0); // Intercept column
226            for j in 0..n_cols {
227                data.push(x.get(i, j));
228            }
229        }
230
231        Matrix::from_vec(n_rows, n_cols + 1, data)
232            .expect("Internal error: failed to create design matrix")
233    }
234}
235
236// Contract: linear-models-v1, equation = "ols_fit"
237impl Estimator for LinearRegression {
238    /// Fits the linear regression model using normal equations.
239    ///
240    /// Solves: β = (X^T X)^-1 X^T y
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if:
245    /// - Input dimensions don't match
246    /// - Not enough samples for the number of features (underdetermined system)
247    /// - Matrix is singular (not positive definite)
248    fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
249        let (n_samples, n_features) = x.shape();
250
251        if n_samples != y.len() {
252            return Err("Number of samples must match target length".into());
253        }
254
255        if n_samples == 0 {
256            return Err("Cannot fit with zero samples".into());
257        }
258
259        // Check for underdetermined system
260        // When fitting intercept, we need n_samples >= n_features + 1
261        // Without intercept, we need n_samples >= n_features
262        let required_samples = if self.fit_intercept {
263            n_features + 1
264        } else {
265            n_features
266        };
267
268        if n_samples < required_samples {
269            return Err(
270                "Insufficient samples: LinearRegression requires at least as many samples as \
271                 features (plus 1 if fitting intercept). Consider using Ridge regression or \
272                 collecting more training data"
273                    .into(),
274            );
275        }
276
277        // Create design matrix (with or without intercept)
278        let x_design = if self.fit_intercept {
279            Self::add_intercept_column(x)
280        } else {
281            x.clone()
282        };
283
284        // Compute X^T X
285        let xt = x_design.transpose();
286        let xtx = xt.matmul(&x_design)?;
287
288        // Compute X^T y
289        let xty = xt.matvec(y)?;
290
291        // Solve normal equations via Cholesky decomposition
292        let beta = xtx.cholesky_solve(&xty)?;
293
294        // Extract intercept and coefficients
295        if self.fit_intercept {
296            self.intercept = beta[0];
297            self.coefficients = Some(beta.slice(1, n_features + 1));
298        } else {
299            self.intercept = 0.0;
300            self.coefficients = Some(beta);
301        }
302
303        Ok(())
304    }
305
306    /// Predicts target values for input data.
307    ///
308    /// # Panics
309    ///
310    /// Panics if model is not fitted.
311    fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
312        let coefficients = self
313            .coefficients
314            .as_ref()
315            .expect("Model not fitted. Call fit() first.");
316
317        let result = x
318            .matvec(coefficients)
319            .expect("Matrix dimensions don't match coefficients");
320
321        result.add_scalar(self.intercept)
322    }
323
324    /// Computes the R² score.
325    fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
326        let y_pred = self.predict(x);
327        r_squared(&y_pred, y)
328    }
329}
330
/// Ridge regression with L2 regularization.
///
/// Fits a linear model with L2 penalty on coefficient magnitudes.
/// The optimization objective is:
///
/// ```text
/// minimize ||y - Xβ||² + α||β||²
/// ```
///
/// where `α` (alpha) controls the regularization strength.
///
/// # Solver
///
/// Uses regularized normal equations: `β = (X^T X + αI)^-1 X^T y`
///
/// # When to use Ridge
///
/// - When you have many correlated features (multicollinearity)
/// - To prevent overfitting with limited samples
/// - When all features are expected to contribute
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
/// use aprender::linear_model::Ridge;
///
/// // Data with some noise
/// let x = Matrix::from_vec(5, 2, vec![
///     1.0, 2.0,
///     2.0, 3.0,
///     3.0, 4.0,
///     4.0, 5.0,
///     5.0, 6.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[5.0, 8.0, 11.0, 14.0, 17.0]);
///
/// let mut model = Ridge::new(1.0);  // alpha = 1.0
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.9);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ridge {
    /// Regularization strength (lambda/alpha). Larger values shrink
    /// coefficients more aggressively toward zero.
    alpha: f32,
    /// Coefficients for features (excluding intercept); `None` until fitted.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term.
    intercept: f32,
    /// Whether to fit an intercept.
    fit_intercept: bool,
}
386
// The remaining linear models live in sibling files and are textually
// spliced into this module via `include!`, so they share this module's
// imports and private items.
include!("lasso.rs");
include!("lasso_impl.rs");
include!("elastic_net.rs");
include!("input.rs");