linreg_core/regularized/elastic_net.rs
1//! Elastic Net regression (L1 + L2 regularized linear regression).
2//!
3//! This module provides a generalized elastic net implementation using cyclical
4//! coordinate descent with soft-thresholding and active set convergence strategies.
5//! It serves as the core engine for both Lasso (`alpha=1.0`) and Ridge (`alpha=0.0`).
6//!
7//! # Objective Function
8//!
9//! Minimizes over `(β₀, β)`:
10//!
11//! ```text
12//! (1/(2n)) * ||y - β₀ - Xβ||² + λ * [ (1-α)||β||₂²/2 + α||β||₁ ]
13//! ```
14//!
15//! Note on scaling: The internal implementation works with standardized data (unit norm columns).
16//! The lambda parameter is adjusted internally to match the scale expected by the formulation above.
17
18use crate::core::{aic, bic, log_likelihood};
19use crate::error::{Error, Result};
20use crate::linalg::Matrix;
21use crate::regularized::preprocess::{
22 predict, standardize_xy, unstandardize_coefficients, StandardizeOptions,
23};
24
25#[cfg(feature = "wasm")]
26use serde::Serialize;
27
/// Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0)
///
/// The key operation in Lasso and Elastic Net regression that applies the
/// L1 penalty: values whose magnitude is at most `gamma` collapse to zero,
/// producing sparse solutions, and larger values are shrunk toward zero.
///
/// # Arguments
///
/// * `z` - Input value to be thresholded
/// * `gamma` - Threshold value (must be non-negative)
///
/// # Returns
///
/// - `z - gamma` if `z > gamma`
/// - `z + gamma` if `z < -gamma`
/// - `0` otherwise (when `|z| <= gamma`)
///
/// # Panics
///
/// Panics if `gamma` is negative.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::soft_threshold;
/// // Values above threshold are reduced
/// assert_eq!(soft_threshold(5.0, 2.0), 3.0);
///
/// // Values below threshold are set to zero
/// assert_eq!(soft_threshold(1.0, 2.0), 0.0);
///
/// // Negative values work symmetrically
/// assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
/// assert_eq!(soft_threshold(-1.0, 2.0), 0.0);
/// ```
#[inline]
pub fn soft_threshold(z: f64, gamma: f64) -> f64 {
    if gamma < 0.0 {
        panic!("Soft threshold gamma must be non-negative");
    }
    // Shrink toward zero by gamma; anything inside [-gamma, gamma] becomes 0.
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        _ => 0.0,
    }
}
75
/// Options for elastic net fitting.
///
/// Configuration options for elastic net regression, which combines L1 and L2 penalties.
///
/// # Fields
///
/// - `lambda` - Regularization strength (≥ 0, higher = more regularization)
/// - `alpha` - Mixing parameter (0 = Ridge, 1 = Lasso, 0.5 = equal mix)
/// - `intercept` - Whether to include an intercept term
/// - `standardize` - Whether to standardize predictors to unit variance
/// - `max_iter` - Maximum coordinate descent iterations
/// - `tol` - Convergence tolerance on coefficient changes
/// - `penalty_factor` - Optional per-feature penalty multipliers
/// - `warm_start` - Optional initial coefficient values for warm starts
/// - `weights` - Optional observation weights
/// - `coefficient_bounds` - Optional (lower, upper) bounds for each coefficient
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::ElasticNetOptions;
/// let options = ElasticNetOptions {
///     lambda: 0.1,
///     alpha: 0.5, // Equal mix of L1 and L2
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// ```
#[derive(Clone, Debug)]
pub struct ElasticNetOptions {
    /// Regularization strength (lambda >= 0). Default: 1.0.
    /// Interpreted on the original data scale (converted internally when
    /// fitting on standardized data).
    pub lambda: f64,
    /// Elastic net mixing parameter (0 <= alpha <= 1). Default: 1.0.
    /// alpha=1 is Lasso, alpha=0 is Ridge.
    pub alpha: f64,
    /// Whether to include an intercept term. Default: true.
    pub intercept: bool,
    /// Whether to standardize predictors. Default: true.
    pub standardize: bool,
    /// Maximum coordinate descent iterations. Default: 100000.
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes. Default: 1e-7.
    pub tol: f64,
    /// Per-feature penalty factors (optional).
    /// If None, all features have penalty factor 1.0.
    /// A factor of `f64::INFINITY` excludes a feature entirely.
    pub penalty_factor: Option<Vec<f64>>,
    /// Initial coefficients for warm start (optional), on the ORIGINAL scale.
    /// If provided, optimization starts from these values instead of zero.
    /// Used for efficient pathwise coordinate descent.
    pub warm_start: Option<Vec<f64>>,
    /// Observation weights (optional).
    /// If provided, must have length equal to the number of observations.
    /// Weights are normalized to sum to 1 internally.
    pub weights: Option<Vec<f64>>,
    /// Coefficient bounds: (lower, upper) for each predictor.
    /// If None, uses (-inf, +inf) for all coefficients (no bounds).
    ///
    /// The bounds vector length must equal the number of predictors (excluding intercept).
    /// For each predictor, the coefficient will be clamped to [lower, upper] after
    /// each coordinate descent update.
    ///
    /// # Examples
    /// * Non-negative least squares: `Some(vec![(0.0, f64::INFINITY); p])`
    /// * Upper bound only: `Some(vec![(-f64::INFINITY, 10.0); p])`
    /// * Both bounds: `Some(vec![(-5.0, 5.0); p])`
    ///
    /// # Notes
    /// * Bounds are applied to coefficients on the ORIGINAL scale, not standardized scale
    /// * The intercept is never bounded
    /// * Each pair must satisfy `lower <= upper`
    pub coefficient_bounds: Option<Vec<(f64, f64)>>,
}
149
150impl Default for ElasticNetOptions {
151 fn default() -> Self {
152 ElasticNetOptions {
153 lambda: 1.0,
154 alpha: 1.0, // Lasso default
155 intercept: true,
156 standardize: true,
157 max_iter: 100000,
158 tol: 1e-7,
159 penalty_factor: None,
160 warm_start: None,
161 weights: None,
162 coefficient_bounds: None,
163 }
164 }
165}
166
/// Result of an elastic net fit.
///
/// Contains the fitted model coefficients, convergence information, and diagnostic metrics.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
/// # use linreg_core::linalg::Matrix;
/// # let y = vec![2.0, 4.0, 6.0, 8.0];
/// # let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let options = ElasticNetOptions { lambda: 0.1, alpha: 0.5, intercept: true, standardize: true, ..Default::default() };
/// let fit = elastic_net_fit(&x, &y, &options).unwrap();
///
/// // Access fit results
/// println!("Lambda: {}, Alpha: {}", fit.lambda, fit.alpha);
/// println!("Non-zero coefficients: {}", fit.n_nonzero);
/// println!("Converged: {}", fit.converged);
/// println!("R²: {}", fit.r_squared);
/// println!("AIC: {}", fit.aic);
/// # Ok::<(), linreg_core::Error>(())
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "wasm", derive(Serialize))]
pub struct ElasticNetFit {
    /// Regularization strength used (reported on the original data scale)
    pub lambda: f64,
    /// Elastic net mixing parameter (0 = Ridge, 1 = Lasso)
    pub alpha: f64,
    /// Intercept coefficient (never penalized)
    pub intercept: f64,
    /// Slope coefficients on the original data scale (may be sparse for high alpha)
    pub coefficients: Vec<f64>,
    /// Predicted values on the training data
    pub fitted_values: Vec<f64>,
    /// Residuals (y - fitted_values)
    pub residuals: Vec<f64>,
    /// Number of non-zero coefficients (excluding intercept)
    pub n_nonzero: usize,
    /// Number of coordinate descent iterations performed
    pub iterations: usize,
    /// Whether the algorithm converged within `max_iter`
    pub converged: bool,
    /// Coefficient of determination
    pub r_squared: f64,
    /// Adjusted R² (effective df approximated by 1 + non-zero count)
    pub adj_r_squared: f64,
    /// Mean squared error (residual SS over effective degrees of freedom)
    pub mse: f64,
    /// Root mean squared error
    pub rmse: f64,
    /// Mean absolute error
    pub mae: f64,
    /// Log-likelihood of the model (for model comparison)
    pub log_likelihood: f64,
    /// Akaike Information Criterion (lower = better)
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better)
    pub bic: f64,
}
230
231use crate::regularized::path::{make_lambda_path, LambdaPathOptions};
232
233/// Fits an elastic net regularization path.
234///
235/// This is the most efficient way to fit models for multiple lambda values.
236/// It performs data standardization once and uses warm starts to speed up
237/// convergence along the path.
238///
239/// # Arguments
240///
241/// * `x` - Design matrix
242/// * `y` - Response vector
243/// * `path_options` - Options for generating the lambda path
244/// * `fit_options` - Options for the elastic net fit (alpha, tol, etc.)
245///
246/// # Returns
247///
248/// A vector of `ElasticNetFit` structs, one for each lambda in the path.
249///
250/// # Example
251///
252/// ```
253/// # use linreg_core::regularized::elastic_net::{elastic_net_path, ElasticNetOptions};
254/// # use linreg_core::regularized::path::LambdaPathOptions;
255/// # use linreg_core::linalg::Matrix;
256/// let y = vec![2.0, 4.0, 6.0, 8.0];
257/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
258///
259/// let path_options = LambdaPathOptions {
260/// nlambda: 10,
261/// ..Default::default()
262/// };
263/// let fit_options = ElasticNetOptions {
264/// alpha: 0.5,
265/// ..Default::default()
266/// };
267///
268/// let path = elastic_net_path(&x, &y, &path_options, &fit_options).unwrap();
269/// assert_eq!(path.len(), 10); // One fit per lambda
270///
271/// // First model has strongest regularization (fewest non-zero coefficients)
272/// println!("Non-zero at lambda_max: {}", path[0].n_nonzero);
273/// // Last model has weakest regularization (most non-zero coefficients)
274/// println!("Non-zero at lambda_min: {}", path.last().unwrap().n_nonzero);
275/// # Ok::<(), linreg_core::Error>(())
276/// ```
277pub fn elastic_net_path(
278 x: &Matrix,
279 y: &[f64],
280 path_options: &LambdaPathOptions,
281 fit_options: &ElasticNetOptions,
282) -> Result<Vec<ElasticNetFit>> {
283 let n = x.rows;
284 let p = x.cols;
285
286 if y.len() != n {
287 return Err(Error::DimensionMismatch(format!(
288 "Length of y ({}) must match number of rows in X ({})",
289 y.len(), n
290 )));
291 }
292
293 // 1. Standardize X and y ONCE
294 let standardization_options = StandardizeOptions {
295 intercept: fit_options.intercept,
296 standardize_x: fit_options.standardize,
297 standardize_y: fit_options.intercept,
298 weights: fit_options.weights.clone(),
299 };
300
301 let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);
302
303 // 2. Generate lambda path
304 // If lambdas are not provided in options (which they aren't in LambdaPathOptions,
305 // it just controls generation), we generate them.
306 // NOTE: If the user wants specific lambdas, they should probably use a different API
307 // or we could add `lambdas: Option<&[f64]>` to this function.
308 // For now, we strictly generate them.
309
310 // We need to account for penalty factors in lambda generation if provided
311 let intercept_col = if fit_options.intercept { Some(0) } else { None };
312 let lambdas = make_lambda_path(
313 &x_standardized,
314 &y_standardized, // y_standardized is centered if intercept=true
315 path_options,
316 fit_options.penalty_factor.as_deref(),
317 intercept_col
318 );
319
320 // 3. Loop over lambdas with warm starts
321 let mut fits = Vec::with_capacity(lambdas.len());
322 let mut coefficients_standardized = vec![0.0; p]; // Initialize at 0
323
324 // Determine unpenalized columns
325 let first_penalized_column_index = if fit_options.intercept { 1 } else { 0 };
326
327 // Calculate scale factor for converting Internal lambdas to Public (user-facing) lambdas
328 // make_lambda_path returns Internal lambdas (for standardized data)
329 // We use these directly in the solver, but scale them for user reporting
330 let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
331 // Public lambda = Internal lambda * y_scale_factor
332 // This converts from standardized scale to original data scale
333 let lambda_conversion_factor = if y_scale_factor > 1e-12 {
334 y_scale_factor
335 } else {
336 1.0
337 };
338
339 for &lambda_standardized_value in &lambdas {
340 // The path generation returns lambdas on the internal scale (for standardized data),
341 // which are used directly in coordinate descent without additional scaling.
342 let lambda_standardized = lambda_standardized_value;
343
344 // Transform coefficient bounds to standardized scale
345 // Bounds on original scale need to be converted: coefficients_standardized = beta_orig * x_scale / y_scale
346 let bounds_standardized: Option<Vec<(f64, f64)>> = fit_options.coefficient_bounds.as_ref().map(|bounds| {
347 let y_scale = standardization_info.y_scale.unwrap_or(1.0);
348 bounds.iter().enumerate().map(|(j, &(lower, upper))| {
349 // For each predictor j in original scale, the corresponding column
350 // in the standardized matrix is at index j+1 (col 0 is intercept)
351 let std_idx = j + 1;
352 let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
353 standardization_info.x_scale[std_idx]
354 } else {
355 1.0
356 };
357 let scale_factor = x_scale_predictor_j / y_scale;
358 (lower * scale_factor, upper * scale_factor)
359 }).collect()
360 });
361
362 let (iterations, converged) = coordinate_descent(
363 &x_standardized,
364 &y_standardized,
365 &mut coefficients_standardized,
366 lambda_standardized,
367 fit_options.alpha,
368 first_penalized_column_index,
369 fit_options.max_iter,
370 fit_options.tol,
371 fit_options.penalty_factor.as_deref(),
372 bounds_standardized.as_deref(),
373 &standardization_info.column_squared_norms,
374 )?;
375
376 // Unstandardize coefficients for output
377 let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);
378
379 // Count non-zeros
380 let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();
381
382 // Fitted values & residuals
383 let fitted = predict(x, intercept, &beta_orig);
384 let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();
385
386 // Statistics
387 let y_mean = y.iter().sum::<f64>() / n as f64;
388 let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
389 let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
390 let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
391
392 let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
393 let eff_df = 1.0 + n_nonzero as f64;
394 let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
395 1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
396 } else {
397 r_squared
398 };
399 let mse = ss_res / (n as f64 - eff_df).max(1.0);
400
401 // Model selection criteria
402 let ll = log_likelihood(n, mse, ss_res);
403 let n_coef = beta_orig.len() + 1; // coefficients + intercept
404 let aic_val = aic(ll, n_coef);
405 let bic_val = bic(ll, n_coef, n);
406
407 // Convert Internal lambda to Public (user-facing) lambda for reporting
408 // Public = Internal * y_scale_var * n (to match R's glmnet reporting)
409 let lambda_original_scale = lambda_standardized_value * lambda_conversion_factor;
410
411 fits.push(ElasticNetFit {
412 lambda: lambda_original_scale,
413 alpha: fit_options.alpha,
414 intercept,
415 coefficients: beta_orig,
416 fitted_values: fitted,
417 residuals,
418 n_nonzero,
419 iterations,
420 converged,
421 r_squared,
422 adj_r_squared,
423 mse,
424 rmse: mse.sqrt(),
425 mae,
426 log_likelihood: ll,
427 aic: aic_val,
428 bic: bic_val,
429 });
430 }
431
432 Ok(fits)
433}
434
435/// Fits elastic net regression for a single (lambda, alpha) pair.
436///
437/// Elastic net combines L1 (Lasso) and L2 (Ridge) penalties:
438/// - `alpha = 1.0` is pure Lasso (L1 only)
439/// - `alpha = 0.0` is pure Ridge (L2 only)
440/// - `alpha = 0.5` is an equal mix
441///
442/// # Arguments
443///
444/// * `x` - Design matrix (n rows × p columns including intercept)
445/// * `y` - Response variable (n observations)
446/// * `options` - Configuration options for elastic net regression
447///
448/// # Returns
449///
450/// An `ElasticNetFit` containing coefficients, convergence info, and metrics.
451///
452/// # Example
453///
454/// ```
455/// # use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
456/// # use linreg_core::linalg::Matrix;
457/// let y = vec![2.0, 4.0, 6.0, 8.0];
458/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
459///
460/// // Elastic net with 50% L1, 50% L2
461/// let options = ElasticNetOptions {
462/// lambda: 0.1,
463/// alpha: 0.5,
464/// intercept: true,
465/// standardize: true,
466/// ..Default::default()
467/// };
468///
469/// let fit = elastic_net_fit(&x, &y, &options).unwrap();
470/// assert!(fit.converged);
471/// println!("R²: {}", fit.r_squared);
472/// # Ok::<(), linreg_core::Error>(())
473/// ```
474pub fn elastic_net_fit(x: &Matrix, y: &[f64], options: &ElasticNetOptions) -> Result<ElasticNetFit> {
475 if options.lambda < 0.0 {
476 return Err(Error::InvalidInput("Lambda must be non-negative".into()));
477 }
478 if options.alpha < 0.0 || options.alpha > 1.0 {
479 return Err(Error::InvalidInput("Alpha must be between 0 and 1".into()));
480 }
481
482 let n = x.rows;
483 let p = x.cols;
484
485 if y.len() != n {
486 return Err(Error::DimensionMismatch(format!(
487 "Length of y ({}) must match number of rows in X ({})",
488 y.len(),
489 n
490 )));
491 }
492
493 // Validate coefficient bounds
494 let n_predictors = if options.intercept { p - 1 } else { p };
495 if let Some(ref bounds) = options.coefficient_bounds {
496 if bounds.len() != n_predictors {
497 return Err(Error::InvalidInput(format!(
498 "Coefficient bounds length ({}) must match number of predictors ({})",
499 bounds.len(), n_predictors
500 )));
501 }
502 for (i, &(lower, upper)) in bounds.iter().enumerate() {
503 if lower > upper {
504 return Err(Error::InvalidInput(format!(
505 "Coefficient bounds for predictor {}: lower ({}) must be <= upper ({})",
506 i, lower, upper
507 )));
508 }
509 // Note: We allow (-inf, +inf) as it represents "no bounds" for that predictor
510 // This is useful for having mixed bounded/unbounded predictors
511 }
512 }
513
514 // Standardize X and y
515 // glmnet convention: y is always centered/scaled if intercept is present
516 let standardization_options = StandardizeOptions {
517 intercept: options.intercept,
518 standardize_x: options.standardize,
519 standardize_y: options.intercept,
520 weights: options.weights.clone(),
521 };
522
523 let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);
524
525 // Adjust lambda for scaling
526 // The path generation returns internal lambdas (for standardized data),
527 // which are used directly in coordinate descent.
528 //
529 // For single-lambda fits, the user provides "public" lambda values
530 // (like R reports), which need to be converted to "internal" scale:
531 // lambda_standardized_value = lambda_original_scale / y_scale
532 let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
533 let lambda_standardized = if y_scale_factor > 1e-12 {
534 options.lambda / y_scale_factor
535 } else {
536 options.lambda
537 };
538
539 // DEBUG: Print scaling info
540 // #[cfg(debug_assertions)]
541 // {
542 // eprintln!("DEBUG elastic_net_fit: user_lambda = {}, y_scale = {}, lambda_standardized = {}",
543 // options.lambda, y_scale_factor, lambda_standardized);
544 // }
545
546 // Initial coefficients (all zeros)
547 let mut coefficients_standardized = vec![0.0; p];
548
549 // Determine unpenalized columns (e.g. intercept column 0 if manually added,
550 // but standardize_xy handles the intercept externally usually.
551 // If intercept=true, standardize_xy centers data and we don't penalize an implicit intercept.
552 // Here we assume x contains PREDICTORS only if intercept is handled by standardization centering.
553 // However, the `Matrix` struct might include a column of 1s if the user passed it.
554 // `standardize_xy` treats all columns in X as predictors to be standardized.
555 // If options.intercept is true, we compute the intercept from the means later.
556 // We assume X passed here does NOT contain a manual intercept column of 1s unless
557 // the user explicitly wants to penalize it (which is weird) or turned off intercept in options.
558 // For now, we penalize all columns in X according to penalty_factors.
559
560 // Check if we assume X has an intercept column at 0 that we should skip?
561 // The previous ridge/lasso implementations had a `first_penalized_column_index` logic:
562 // `let first_penalized_column_index = if options.intercept { 1 } else { 0 };`
563 // This implies `x` might have a column of 1s.
564 // GLMNET convention usually takes x matrix of predictors only.
565 // `standardize_xy` calculates means for ALL columns.
566 // If column 0 is all 1s, std dev is 0, standardization might fail or set to 0.
567 // Let's stick to the previous `lasso.rs` logic: if intercept is requested, we ignore column 0?
568 // `lasso.rs`: "Determine which columns are penalized. first_penalized_column_index = if options.intercept { 1 } else { 0 }"
569 // This strongly suggests the input Matrix `x` is expected to have a column of 1s at index 0 if intercept=true.
570 // We will preserve this behavior for compatibility with existing tests.
571 // i.e. this is going to be hell to refactor and I'm idly typing my thoughts away...
572 // This is a naive implementation anyways and only one head of the hydra that is glmnet.
573 let first_penalized_column_index = if options.intercept { 1 } else { 0 };
574
575 // Warm start initialization
576 if let Some(warm) = &options.warm_start {
577 // warm contains slope coefficients on ORIGINAL scale
578 // We need to transform them to STANDARDIZED scale
579 // coefficients_standardized = beta_orig * x_scale / y_scale
580 let y_scale = standardization_info.y_scale.unwrap_or(1.0);
581
582 if first_penalized_column_index == 1 {
583 // Case 1: Intercept at col 0
584 // warm start vector should correspond to cols 1..p (slopes)
585 // coefficients_standardized[0] stays 0.0 (intercept of centered data is 0)
586 if warm.len() == p - 1 {
587 for j in 1..p {
588 coefficients_standardized[j] = warm[j - 1] * standardization_info.x_scale[j] / y_scale;
589 }
590 } else {
591 // If dimensions don't match, ignore warm start or warn?
592 // For safety in this "todo" fix, we'll just ignore mismatched warm starts to avoid panics,
593 // but usually this indicates a caller error.
594 // Given I can't print warnings easily here, I'll ignore or maybe assume warm includes intercept?
595 // If warm has length p, maybe it includes intercept? But ElasticNetFit.coefficients excludes it.
596 // Let's stick to: warm start matches slopes.
597 }
598 } else {
599 // Case 2: No intercept column
600 if warm.len() == p {
601 for j in 0..p {
602 coefficients_standardized[j] = warm[j] * standardization_info.x_scale[j] / y_scale;
603 }
604 }
605 }
606 }
607
608 // Transform coefficient bounds to standardized scale
609 // Bounds on original scale need to be converted: coefficients_standardized = beta_orig * x_scale / y_scale
610 let bounds_standardized: Option<Vec<(f64, f64)>> = options.coefficient_bounds.as_ref().map(|bounds| {
611 let y_scale = standardization_info.y_scale.unwrap_or(1.0);
612 bounds.iter().enumerate().map(|(j, &(lower, upper))| {
613 // For each predictor j in original scale, the corresponding column
614 // in the standardized matrix is at index j+1 (col 0 is intercept)
615 let std_idx = j + 1;
616 let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
617 standardization_info.x_scale[std_idx]
618 } else {
619 1.0
620 };
621 let scale_factor = x_scale_predictor_j / y_scale;
622 (lower * scale_factor, upper * scale_factor)
623 }).collect()
624 });
625
626 let (iterations, converged) = coordinate_descent(
627 &x_standardized,
628 &y_standardized,
629 &mut coefficients_standardized,
630 lambda_standardized,
631 options.alpha,
632 first_penalized_column_index,
633 options.max_iter,
634 options.tol,
635 options.penalty_factor.as_deref(),
636 bounds_standardized.as_deref(),
637 &standardization_info.column_squared_norms,
638 )?;
639
640 // Unstandardize
641 let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);
642
643 // Count nonzero (excluding intercept)
644 // beta_orig contains slopes. If first_penalized_column_index=1, coefficients_standardized[0] was 0.
645 // The coefficients returned should correspond to the columns of X (excluding the manual intercept if present?).
646 // `unstandardize_coefficients` handles the mapping.
647 let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();
648
649 // Fitted values
650 let fitted = predict(x, intercept, &beta_orig);
651 let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();
652
653 // Statistics
654 let y_mean = y.iter().sum::<f64>() / n as f64;
655 let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
656 let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
657 let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
658
659 let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
660
661 // Effective DF approximation for Elastic Net
662 // df ≈ n_nonzero for Lasso
663 // df ≈ trace(S) for Ridge
664 // We use a naive approximation here: n_nonzero
665 let eff_df = 1.0 + n_nonzero as f64;
666 let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
667 1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
668 } else {
669 r_squared
670 };
671
672 let mse = ss_res / (n as f64 - eff_df).max(1.0);
673
674 // Model selection criteria
675 let ss_res: f64 = residuals.iter().map(|&r| r * r).sum();
676 let ll = log_likelihood(n, mse, ss_res);
677 let n_coef = beta_orig.len() + 1; // coefficients + intercept
678 let aic_val = aic(ll, n_coef);
679 let bic_val = bic(ll, n_coef, n);
680
681 Ok(ElasticNetFit {
682 lambda: options.lambda,
683 alpha: options.alpha,
684 intercept,
685 coefficients: beta_orig,
686 fitted_values: fitted,
687 residuals,
688 n_nonzero,
689 iterations,
690 converged,
691 r_squared,
692 adj_r_squared,
693 mse,
694 rmse: mse.sqrt(),
695 mae,
696 log_likelihood: ll,
697 aic: aic_val,
698 bic: bic_val,
699 })
700}
701
/// Cyclical coordinate descent with an active-set strategy.
///
/// Minimizes the elastic net objective on (already standardized) data,
/// updating `beta` in place. Alternates a full pass over all penalized
/// columns with repeated passes over the current active set (columns with
/// non-zero coefficients) until the largest coefficient change in a pass
/// drops below `tol` or `max_iter` passes have been performed.
///
/// # Arguments
///
/// * `x` - Standardized design matrix
/// * `y` - Standardized response
/// * `beta` - Coefficients on the standardized scale; read as the starting
///   point (warm start) and overwritten with the solution
/// * `lambda` / `alpha` - Penalty strength and L1/L2 mix on the internal scale
/// * `first_penalized_column_index` - Columns before this index (the intercept
///   column, if present) are never updated or penalized
/// * `penalty_factor` - Optional per-column penalty multipliers
/// * `bounds` - Optional per-predictor coefficient bounds (standardized scale)
/// * `column_squared_norms` - Squared norm of each column of `x`
///
/// # Returns
///
/// `(iterations, converged)` where `iterations` counts full+active-set passes.
/// Returns `(0, false)` without iterating if `y` contains non-finite values.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn coordinate_descent(
    x: &Matrix,
    y: &[f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    first_penalized_column_index: usize,
    max_iter: usize,
    tol: f64,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64], // Column squared norms (for coordinate descent update)
) -> Result<(usize, bool)> {
    let n = x.rows;
    let p = x.cols;

    // Residuals r = y - Xβ
    // Initialize with all betas zero -> residuals = y
    // If y contains infinity/NaN, residuals will too
    let mut residuals = y.to_vec();

    // Check for non-finite residuals initially - if present, we can't optimize
    if residuals.iter().any(|r| !r.is_finite()) {
        return Ok((0, false));
    }

    // Handle non-zero initial betas (warm starts): subtract their contribution
    // so that `residuals` is consistent with the starting `beta`.
    for j in 0..p {
        if beta[j] != 0.0 {
            for i in 0..n {
                residuals[i] -= x.get(i, j) * beta[j];
            }
        }
    }

    // Active set: indices of non-zero coefficients
    let mut active_set = vec![false; p];

    let mut converged = false;
    let mut iter = 0;

    while iter < max_iter {
        let mut maximum_coefficient_change = 0.0;

        // --- Full Pass --- over every penalized column; any column whose
        // coefficient changed joins the active set.
        for j in first_penalized_column_index..p {
            if update_feature(j, x, &mut residuals, beta, lambda, alpha, penalty_factor, bounds, column_squared_norms, &mut maximum_coefficient_change) {
                active_set[j] = true;
            }
        }
        iter += 1;

        // Converged when even the full pass barely moved any coefficient.
        if maximum_coefficient_change < tol {
            converged = true;
            break;
        }

        // --- Active Set Loop --- iterate only over currently non-zero
        // coefficients until they stabilize, then re-run a full pass.
        loop {
            if iter >= max_iter { break; }

            let mut active_set_coefficient_change = 0.0;
            let mut active_count = 0;

            for j in first_penalized_column_index..p {
                if active_set[j] {
                    update_feature(j, x, &mut residuals, beta, lambda, alpha, penalty_factor, bounds, column_squared_norms, &mut active_set_coefficient_change);
                    active_count += 1;

                    // Drop columns that were shrunk back to exactly zero.
                    if beta[j] == 0.0 {
                        active_set[j] = false;
                    }
                }
            }

            iter += 1;

            if active_set_coefficient_change < tol {
                break;
            }

            if active_count == 0 {
                break;
            }
        }
    }

    Ok((iter, converged))
}
793
/// Performs a single coordinate descent update for column `j`.
///
/// Computes the soft-thresholded, ridge-damped update for `beta[j]` given the
/// current residuals, applies optional per-predictor bounds, and — if the
/// coefficient changed — updates `residuals` in place and records the change
/// magnitude into `maximum_coefficient_change`.
///
/// Returns `true` iff `beta[j]` changed.
#[inline]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn update_feature(
    j: usize,
    x: &Matrix,
    residuals: &mut [f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64], // Column squared norms (for coordinate descent update)
    maximum_coefficient_change: &mut f64
) -> bool {
    // Penalty factor; an infinite factor excludes the feature entirely.
    let penalty_factor_value = penalty_factor.and_then(|v| v.get(j)).copied().unwrap_or(1.0);
    if penalty_factor_value == f64::INFINITY {
        beta[j] = 0.0;
        return false;
    }

    let n = x.rows;
    let coefficient_previous = beta[j];

    // Calculate partial residual correlation (rho)
    // residuals currently = y - Sum(Xk * beta_k)
    // We want r_partial = y - Sum_{k!=j}(Xk * beta_k) = residuals + Xj * beta_j
    // rho = Xj^T * r_partial = Xj^T * residuals + (Xj^T * Xj) * beta_j
    // where Xj^T * Xj = column_squared_norms[j] (the squared norm of column j after standardization)

    let mut partial_correlation_unscaled = 0.0;
    for i in 0..n {
        partial_correlation_unscaled += x.get(i, j) * residuals[i];
    }
    // Use column_squared_norms[j] instead of assuming 1.0
    let rho = partial_correlation_unscaled + column_squared_norms[j] * coefficient_previous;

    // Soft thresholding applies the L1 part of the penalty.
    // Numerator: S(rho, lambda * alpha * penalty_factor_value)
    let threshold = lambda * alpha * penalty_factor_value;
    let soft_threshold_result = soft_threshold(rho, threshold);

    // Denominator applies the L2 (ridge) part of the penalty.
    // Elastic net denominator: column_squared_norms[j] + lambda * (1 - alpha) * penalty_factor_value
    // This matches glmnet's formula
    let denominator_with_ridge_penalty = column_squared_norms[j] + lambda * (1.0 - alpha) * penalty_factor_value;

    let mut coefficient_updated = soft_threshold_result / denominator_with_ridge_penalty;

    // Apply coefficient bounds (clamping) if provided
    // Bounds clamp the calculated value to [lower, upper]
    if let Some(bounds) = bounds {
        // bounds[j-1] because bounds is indexed by predictor (excluding intercept)
        // and j starts at first_penalized_column_index (usually 1 for intercept models)
        //
        // NOTE(review): this assumes column 0 is an unpenalized intercept. When the
        // model has no intercept (first_penalized_column_index == 0), saturating_sub
        // maps BOTH j == 0 and j == 1 to bounds[0] — a likely off-by-one for the
        // no-intercept + bounds case. Confirm against callers before relying on
        // bounds without an intercept.
        let bounds_idx = j.saturating_sub(1);
        if let Some((lower, upper)) = bounds.get(bounds_idx) {
            coefficient_updated = coefficient_updated.max(*lower).min(*upper);
        }
    }

    // Update residuals if beta changed
    if coefficient_updated != coefficient_previous {
        let coefficient_change = coefficient_updated - coefficient_previous;
        for i in 0..n {
            // residuals_new = residuals_old - x_j * coefficient_change
            residuals[i] -= x.get(i, j) * coefficient_change;
        }
        beta[j] = coefficient_updated;
        *maximum_coefficient_change = maximum_coefficient_change.max(coefficient_change.abs());
        true // changed
    } else {
        false // no change
    }
}