linreg_core/regularized/ridge.rs
1//! Ridge regression (L2-regularized linear regression).
2//!
3//! This module provides a wrapper around the elastic net implementation with `alpha=0.0`.
4
5use crate::core::{aic, bic, log_likelihood};
6use crate::error::Result;
7use crate::linalg::Matrix;
8use crate::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
9use crate::regularized::preprocess::predict;
10use crate::serialization::types::ModelType;
11use crate::impl_serialization;
12use serde::{Deserialize, Serialize};
13
/// Options for ridge regression fitting.
///
/// Configuration options for ridge regression (L2-regularized linear regression).
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::ridge::RidgeFitOptions;
/// let options = RidgeFitOptions {
///     lambda: 1.0,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// ```
#[derive(Clone, Debug)]
pub struct RidgeFitOptions {
    /// Regularization strength (≥ 0); larger values shrink coefficients more.
    pub lambda: f64,
    /// Whether to include an (unpenalized) intercept term.
    pub intercept: bool,
    /// Whether to standardize predictors to unit variance before fitting.
    pub standardize: bool,
    /// Maximum number of coordinate descent iterations.
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes.
    pub tol: f64,
    /// Optional initial coefficient values for warm starts.
    pub warm_start: Option<Vec<f64>>,
    /// Optional per-observation weights.
    pub weights: Option<Vec<f64>>,
}

impl Default for RidgeFitOptions {
    /// Defaults: unit penalty, intercept and standardization enabled,
    /// and a generous iteration budget with a tight tolerance.
    fn default() -> Self {
        Self {
            lambda: 1.0,
            intercept: true,
            standardize: true,
            max_iter: 100_000,
            tol: 1e-7,
            warm_start: None,
            weights: None,
        }
    }
}
63
/// Result of a ridge regression fit.
///
/// Contains the fitted model coefficients, predictions, and diagnostic metrics.
/// Serializable via serde for model persistence (see `impl_serialization!` usage
/// at the bottom of this module).
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::ridge::{ridge_fit, RidgeFitOptions};
/// # use linreg_core::linalg::Matrix;
/// # let y = vec![2.0, 4.0, 6.0, 8.0];
/// # let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let options = RidgeFitOptions { lambda: 0.1, intercept: true, standardize: false, ..Default::default() };
/// let fit = ridge_fit(&x, &y, &options).unwrap();
///
/// // Access model coefficients
/// println!("Intercept: {}", fit.intercept);
/// println!("Slopes: {:?}", fit.coefficients);
///
/// // Access predictions and diagnostics
/// println!("R²: {}", fit.r_squared);
/// println!("RMSE: {}", fit.rmse);
/// println!("AIC: {}", fit.aic);
/// # Ok::<(), linreg_core::Error>(())
/// ```
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RidgeFit {
    /// The regularization strength used for this fit.
    pub lambda: f64,
    /// Intercept coefficient (never penalized).
    pub intercept: f64,
    /// Slope coefficients (penalized).
    pub coefficients: Vec<f64>,
    /// Predicted values on the training data.
    pub fitted_values: Vec<f64>,
    /// Residuals (y - fitted_values).
    pub residuals: Vec<f64>,
    /// Approximate effective degrees of freedom (closed-form approximation,
    /// not the exact SVD-based value — see `ridge_fit` for details).
    pub df: f64,
    /// Coefficient of determination (R²).
    pub r_squared: f64,
    /// Adjusted R².
    pub adj_r_squared: f64,
    /// Mean squared error.
    pub mse: f64,
    /// Root mean squared error.
    pub rmse: f64,
    /// Mean absolute error.
    pub mae: f64,
    /// Log-likelihood of the model (for model comparison).
    pub log_likelihood: f64,
    /// Akaike Information Criterion (lower = better).
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better).
    pub bic: f64,
}
122
123/// Fits ridge regression for a single lambda value.
124///
125/// Ridge regression adds an L2 penalty to the coefficients, which helps with
126/// multicollinearity and overfitting. The intercept is never penalized.
127///
128/// # Arguments
129///
130/// * `x` - Design matrix (n rows × p columns including intercept)
131/// * `y` - Response variable (n observations)
132/// * `options` - Configuration options for ridge regression
133///
134/// # Returns
135///
136/// A `RidgeFit` containing coefficients, fitted values, residuals, and metrics.
137///
138/// # Example
139///
140/// ```
141/// # use linreg_core::regularized::ridge::{ridge_fit, RidgeFitOptions};
142/// # use linreg_core::linalg::Matrix;
143/// let y = vec![2.0, 4.0, 6.0, 8.0];
144/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
145///
146/// let options = RidgeFitOptions {
147/// lambda: 0.1,
148/// intercept: true,
149/// standardize: false,
150/// ..Default::default()
151/// };
152///
153/// let fit = ridge_fit(&x, &y, &options).unwrap();
154/// assert!(fit.coefficients.len() == 1); // One slope coefficient
155/// assert!(fit.r_squared > 0.9); // Good fit for linear data
156/// # Ok::<(), linreg_core::Error>(())
157/// ```
158pub fn ridge_fit(x: &Matrix, y: &[f64], options: &RidgeFitOptions) -> Result<RidgeFit> {
159 // DEBUG: Print lambda info
160 // #[cfg(debug_assertions)]
161 // {
162 // eprintln!("DEBUG ridge_fit: user_lambda = {}, standardize = {}", options.lambda, options.standardize);
163 // }
164
165 let en_options = ElasticNetOptions {
166 lambda: options.lambda,
167 alpha: 0.0, // Ridge
168 intercept: options.intercept,
169 standardize: options.standardize,
170 max_iter: options.max_iter,
171 tol: options.tol,
172 penalty_factor: None,
173 warm_start: options.warm_start.clone(),
174 weights: options.weights.clone(),
175 coefficient_bounds: None,
176 };
177
178 let fit = elastic_net_fit(x, y, &en_options)?;
179
180 // #[cfg(debug_assertions)]
181 // {
182 // eprintln!("DEBUG ridge_fit: fit.intercept = {}, fit.coefficients[0] = {}", fit.intercept,
183 // fit.coefficients.first().unwrap_or(&0.0));
184 // }
185
186 // Approximation of degrees of freedom for ridge regression.
187 //
188 // The true effective df requires SVD: sum(eigenvalues / (eigenvalues + lambda)).
189 // Since coordinate descent doesn't compute the SVD, we use a closed-form approximation
190 // that works well when X is standardized: df ≈ p / (1 + lambda).
191 //
192 // This approximation is reasonable for most practical purposes. For exact df,
193 // users would need to implement SVD-based calculation separately.
194 let p = x.cols;
195 let df = (p as f64) / (1.0 + options.lambda);
196
197 // Model selection criteria
198 let n = y.len();
199 let ss_res: f64 = fit.residuals.iter().map(|&r| r * r).sum();
200 let ll = log_likelihood(n, fit.mse, ss_res);
201 let n_coef = fit.coefficients.len() + 1; // coefficients + intercept
202 let aic_val = aic(ll, n_coef);
203 let bic_val = bic(ll, n_coef, n);
204
205 Ok(RidgeFit {
206 lambda: fit.lambda,
207 intercept: fit.intercept,
208 coefficients: fit.coefficients,
209 fitted_values: fit.fitted_values,
210 residuals: fit.residuals,
211 df,
212 r_squared: fit.r_squared,
213 adj_r_squared: fit.adj_r_squared,
214 mse: fit.mse,
215 rmse: fit.rmse,
216 mae: fit.mae,
217 log_likelihood: ll,
218 aic: aic_val,
219 bic: bic_val,
220 })
221}
222
/// Makes predictions using a ridge regression fit.
///
/// Computes predictions for new observations using the fitted ridge regression
/// model, applying the stored intercept and slope coefficients to each row.
///
/// # Arguments
///
/// * `fit` - Fitted ridge regression model from [`ridge_fit`]
/// * `x_new` - Design matrix for new observations (same number of columns as
///   the training data, including the intercept column)
///
/// # Returns
///
/// Vector of predicted values, one per row in `x_new`.
///
/// # Panics
///
/// NOTE(review): presumably the underlying `predict` helper panics when
/// `x_new`'s column count does not match the number of coefficients in `fit`
/// — confirm against `regularized::preprocess::predict`.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::ridge::{ridge_fit, predict_ridge, RidgeFitOptions};
/// # use linreg_core::linalg::Matrix;
/// // Training data
/// let y = vec![2.0, 4.0, 6.0, 8.0];
/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
///
/// let options = RidgeFitOptions {
///     lambda: 0.1,
///     intercept: true,
///     standardize: false,
///     ..Default::default()
/// };
/// let fit = ridge_fit(&x, &y, &options).unwrap();
///
/// // Predict on new data
/// let x_new = Matrix::new(2, 2, vec![1.0, 5.0, 1.0, 6.0]);
/// let predictions = predict_ridge(&fit, &x_new);
///
/// assert_eq!(predictions.len(), 2);
/// // Predictions should be close to [10.0, 12.0] for the linear relationship y = 2*x
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn predict_ridge(fit: &RidgeFit, x_new: &Matrix) -> Vec<f64> {
    // Delegates to the shared prediction helper with the fitted intercept
    // and slope coefficients.
    predict(x_new, fit.intercept, &fit.coefficients)
}
279
280// ============================================================================
281// Model Serialization Traits
282// ============================================================================
283
284// Generate ModelSave and ModelLoad implementations using macro
285impl_serialization!(RidgeFit, ModelType::Ridge, "Ridge");
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Ridge with a small penalty on exactly linear data (y = 2x) should land
    /// close to the OLS solution: slope ≈ 2, intercept ≈ 0.
    #[test]
    fn test_ridge_fit_simple() {
        let design = Matrix::new(
            4,
            2,
            vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0],
        );
        let response = vec![2.0, 4.0, 6.0, 8.0];

        let opts = RidgeFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: false,
            ..Default::default()
        };
        let fit = ridge_fit(&design, &response, &opts).unwrap();

        // The small L2 penalty should only slightly shrink the OLS estimates.
        assert!((fit.coefficients[0] - 2.0).abs() < 0.2);
        assert!(fit.intercept.abs() < 0.5);
    }
}
310}