aprender/linear_model/mod.rs
1//! Linear models for regression.
2//!
3//! Includes Ordinary Least Squares (OLS) and regularized regression.
4
5use crate::error::Result;
6use crate::metrics::r_squared;
7use crate::primitives::{Matrix, Vector};
8use crate::traits::Estimator;
9use serde::{Deserialize, Serialize};
10use std::fs;
11use std::path::Path;
12
/// Ordinary Least Squares (OLS) linear regression.
///
/// Fits a linear model by minimizing the residual sum of squares between
/// observed targets and predicted targets. The model equation is:
///
/// ```text
/// y = X β + ε
/// ```
///
/// where `β` is the coefficient vector and `ε` is random error.
///
/// # Solver
///
/// Uses normal equations: `β = (X^T X)^-1 X^T y` via Cholesky decomposition.
/// This is a direct (non-iterative) solve; it fails when `X^T X` is not
/// positive definite (see the `fit` error cases).
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
///
/// // Simple linear regression: y = 2x + 1
/// let x = Matrix::from_vec(4, 1, vec![
///     1.0,
///     2.0,
///     3.0,
///     4.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
///
/// let mut model = LinearRegression::new();
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.99);
/// ```
///
/// # Performance
///
/// - Time complexity: O(n²p + p³) where n = samples, p = features
/// - Space complexity: O(np)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression {
    /// Coefficients for features (excluding intercept). `None` until `fit` succeeds.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term. Stays 0.0 when `fit_intercept` is false.
    intercept: f32,
    /// Whether to fit an intercept.
    fit_intercept: bool,
}
63
64impl Default for LinearRegression {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl LinearRegression {
71 /// Creates a new `LinearRegression` with default settings.
72 #[must_use]
73 pub fn new() -> Self {
74 Self {
75 coefficients: None,
76 intercept: 0.0,
77 fit_intercept: true,
78 }
79 }
80
81 /// Sets whether to fit an intercept term.
82 #[must_use]
83 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
84 self.fit_intercept = fit_intercept;
85 self
86 }
87
88 /// Returns the coefficients (excluding intercept).
89 ///
90 /// # Panics
91 ///
92 /// Panics if model is not fitted.
93 #[must_use]
94 pub fn coefficients(&self) -> &Vector<f32> {
95 self.coefficients
96 .as_ref()
97 .expect("Model not fitted. Call fit() first.")
98 }
99
100 /// Returns the intercept term.
101 #[must_use]
102 pub fn intercept(&self) -> f32 {
103 self.intercept
104 }
105
106 /// Returns true if the model has been fitted.
107 #[must_use]
108 pub fn is_fitted(&self) -> bool {
109 self.coefficients.is_some()
110 }
111
112 /// Saves the model to a binary file using bincode.
113 ///
114 /// # Errors
115 ///
116 /// Returns an error if serialization or file writing fails.
117 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
118 let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
119 fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
120 Ok(())
121 }
122
123 /// Loads a model from a binary file.
124 ///
125 /// # Errors
126 ///
127 /// Returns an error if file reading or deserialization fails.
128 pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
129 let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
130 let model =
131 bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
132 Ok(model)
133 }
134
135 /// Saves the model to `SafeTensors` format.
136 ///
137 /// `SafeTensors` format is compatible with:
138 /// - `HuggingFace` ecosystem
139 /// - Ollama (can convert to GGUF)
140 /// - `PyTorch`, TensorFlow
141 /// - realizar inference engine
142 ///
143 /// # Errors
144 ///
145 /// Returns an error if:
146 /// - Model is not fitted
147 /// - Serialization fails
148 /// - File writing fails
149 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
150 use crate::serialization::safetensors;
151 use std::collections::BTreeMap;
152
153 // Verify model is fitted
154 let coefficients = self
155 .coefficients
156 .as_ref()
157 .ok_or("Cannot save unfitted model. Call fit() first.")?;
158
159 // Prepare tensors (BTreeMap ensures deterministic ordering)
160 let mut tensors = BTreeMap::new();
161
162 // Coefficients tensor
163 let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
164 let coef_shape = vec![coefficients.len()];
165 tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
166
167 // Intercept tensor
168 let intercept_data = vec![self.intercept];
169 let intercept_shape = vec![1];
170 tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
171
172 // Save to SafeTensors format
173 safetensors::save_safetensors(path, &tensors)?;
174 Ok(())
175 }
176
177 /// Loads a model from `SafeTensors` format.
178 ///
179 /// # Errors
180 ///
181 /// Returns an error if:
182 /// - File reading fails
183 /// - `SafeTensors` format is invalid
184 /// - Required tensors are missing
185 pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
186 use crate::serialization::safetensors;
187
188 // Load SafeTensors file
189 let (metadata, raw_data) = safetensors::load_safetensors(path)?;
190
191 // Extract coefficients tensor
192 let coef_meta = metadata
193 .get("coefficients")
194 .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
195 let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
196
197 // Extract intercept tensor
198 let intercept_meta = metadata
199 .get("intercept")
200 .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
201 let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
202
203 // Validate intercept shape
204 if intercept_data.len() != 1 {
205 return Err(format!(
206 "Invalid intercept tensor: expected 1 value, got {}",
207 intercept_data.len()
208 ));
209 }
210
211 // Construct model
212 Ok(Self {
213 coefficients: Some(Vector::from_vec(coef_data)),
214 intercept: intercept_data[0],
215 fit_intercept: true, // Default to true for loaded models
216 })
217 }
218
219 /// Adds an intercept column of ones to the design matrix.
220 fn add_intercept_column(x: &Matrix<f32>) -> Matrix<f32> {
221 let (n_rows, n_cols) = x.shape();
222 let mut data = Vec::with_capacity(n_rows * (n_cols + 1));
223
224 for i in 0..n_rows {
225 data.push(1.0); // Intercept column
226 for j in 0..n_cols {
227 data.push(x.get(i, j));
228 }
229 }
230
231 Matrix::from_vec(n_rows, n_cols + 1, data)
232 .expect("Internal error: failed to create design matrix")
233 }
234}
235
236// Contract: linear-models-v1, equation = "ols_fit"
237impl Estimator for LinearRegression {
238 /// Fits the linear regression model using normal equations.
239 ///
240 /// Solves: β = (X^T X)^-1 X^T y
241 ///
242 /// # Errors
243 ///
244 /// Returns an error if:
245 /// - Input dimensions don't match
246 /// - Not enough samples for the number of features (underdetermined system)
247 /// - Matrix is singular (not positive definite)
248 fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
249 let (n_samples, n_features) = x.shape();
250
251 if n_samples != y.len() {
252 return Err("Number of samples must match target length".into());
253 }
254
255 if n_samples == 0 {
256 return Err("Cannot fit with zero samples".into());
257 }
258
259 // Check for underdetermined system
260 // When fitting intercept, we need n_samples >= n_features + 1
261 // Without intercept, we need n_samples >= n_features
262 let required_samples = if self.fit_intercept {
263 n_features + 1
264 } else {
265 n_features
266 };
267
268 if n_samples < required_samples {
269 return Err(
270 "Insufficient samples: LinearRegression requires at least as many samples as \
271 features (plus 1 if fitting intercept). Consider using Ridge regression or \
272 collecting more training data"
273 .into(),
274 );
275 }
276
277 // Create design matrix (with or without intercept)
278 let x_design = if self.fit_intercept {
279 Self::add_intercept_column(x)
280 } else {
281 x.clone()
282 };
283
284 // Compute X^T X
285 let xt = x_design.transpose();
286 let xtx = xt.matmul(&x_design)?;
287
288 // Compute X^T y
289 let xty = xt.matvec(y)?;
290
291 // Solve normal equations via Cholesky decomposition
292 let beta = xtx.cholesky_solve(&xty)?;
293
294 // Extract intercept and coefficients
295 if self.fit_intercept {
296 self.intercept = beta[0];
297 self.coefficients = Some(beta.slice(1, n_features + 1));
298 } else {
299 self.intercept = 0.0;
300 self.coefficients = Some(beta);
301 }
302
303 Ok(())
304 }
305
306 /// Predicts target values for input data.
307 ///
308 /// # Panics
309 ///
310 /// Panics if model is not fitted.
311 fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
312 let coefficients = self
313 .coefficients
314 .as_ref()
315 .expect("Model not fitted. Call fit() first.");
316
317 let result = x
318 .matvec(coefficients)
319 .expect("Matrix dimensions don't match coefficients");
320
321 result.add_scalar(self.intercept)
322 }
323
324 /// Computes the R² score.
325 fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
326 let y_pred = self.predict(x);
327 r_squared(&y_pred, y)
328 }
329}
330
/// Ridge regression with L2 regularization.
///
/// Fits a linear model with L2 penalty on coefficient magnitudes.
/// The optimization objective is:
///
/// ```text
/// minimize ||y - Xβ||² + α||β||²
/// ```
///
/// where `α` (alpha) controls the regularization strength.
///
/// # Solver
///
/// Uses regularized normal equations: `β = (X^T X + αI)^-1 X^T y`
///
/// # When to use Ridge
///
/// - When you have many correlated features (multicollinearity)
/// - To prevent overfitting with limited samples
/// - When all features are expected to contribute
///
/// # Examples
///
/// ```
/// use aprender::prelude::*;
/// use aprender::linear_model::Ridge;
///
/// // Data with some noise
/// let x = Matrix::from_vec(5, 2, vec![
///     1.0, 2.0,
///     2.0, 3.0,
///     3.0, 4.0,
///     4.0, 5.0,
///     5.0, 6.0,
/// ]).expect("Valid matrix dimensions");
/// let y = Vector::from_slice(&[5.0, 8.0, 11.0, 14.0, 17.0]);
///
/// let mut model = Ridge::new(1.0); // alpha = 1.0
/// model.fit(&x, &y).expect("Fit should succeed with valid data");
///
/// let predictions = model.predict(&x);
/// let r2 = model.score(&x, &y);
/// assert!(r2 > 0.9);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ridge {
    /// Regularization strength (lambda/alpha). Larger values shrink
    /// coefficients more aggressively; 0 reduces to OLS.
    alpha: f32,
    /// Coefficients for features (excluding intercept). `None` until fitted.
    coefficients: Option<Vector<f32>>,
    /// Intercept (bias) term.
    intercept: f32,
    /// Whether to fit an intercept.
    fit_intercept: bool,
}
386
// Sibling source files are textually included (rather than declared as `mod`s)
// so their items live directly in the `linear_model` namespace alongside
// `LinearRegression` and `Ridge`.
include!("lasso.rs");
include!("lasso_impl.rs");
include!("elastic_net.rs");
include!("input.rs");