linreg_core/regularized/elastic_net.rs

//! Elastic Net regression (L1 + L2 regularized linear regression).
//!
//! This module provides a generalized elastic net implementation using cyclical
//! coordinate descent with soft-thresholding and an active-set convergence strategy.
//! It serves as the core engine for both Lasso (`alpha=1.0`) and Ridge (`alpha=0.0`).
//!
//! # Objective Function
//!
//! Minimizes over `(β₀, β)`:
//!
//! ```text
//! (1/(2n)) * ||y - β₀ - Xβ||² + λ * [ (1-α)||β||₂²/2 + α||β||₁ ]
//! ```
//!
//! Note on scaling: The internal implementation works with standardized data (unit norm columns).
//! The lambda parameter is adjusted internally to match the scale expected by the formulation above.
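//!
//! # Example
//!
//! A minimal sketch: both Lasso and Ridge run through the same entry point,
//! selected by `alpha`. The 4×2 design below includes a manual intercept
//! column of 1s at index 0, matching the convention used throughout this module.
//!
//! ```
//! use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
//! use linreg_core::linalg::Matrix;
//!
//! let y = vec![2.0, 4.0, 6.0, 8.0];
//! let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
//!
//! let lasso = elastic_net_fit(&x, &y, &ElasticNetOptions { alpha: 1.0, ..Default::default() })?;
//! let ridge = elastic_net_fit(&x, &y, &ElasticNetOptions { alpha: 0.0, ..Default::default() })?;
//! assert_eq!(lasso.coefficients.len(), ridge.coefficients.len());
//! # Ok::<(), linreg_core::Error>(())
//! ```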

use crate::error::{Error, Result};
use crate::linalg::Matrix;
use crate::regularized::preprocess::{
    predict, standardize_xy, unstandardize_coefficients, StandardizeOptions,
};

#[cfg(feature = "wasm")]
use serde::Serialize;

/// Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0)
///
/// This is the key operation in Lasso and Elastic Net regression that applies
/// the L1 penalty, producing sparse solutions by shrinking small values to zero.
///
/// # Arguments
///
/// * `z` - Input value to be thresholded
/// * `gamma` - Threshold value (must be non-negative)
///
/// # Returns
///
/// - `z - gamma` if `z > gamma`
/// - `z + gamma` if `z < -gamma`
/// - `0` otherwise (when `|z| <= gamma`)
///
/// # Panics
///
/// Panics if `gamma` is negative.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::soft_threshold;
/// // Values above threshold are reduced
/// assert_eq!(soft_threshold(5.0, 2.0), 3.0);
///
/// // Values below threshold are set to zero
/// assert_eq!(soft_threshold(1.0, 2.0), 0.0);
///
/// // Negative values work symmetrically
/// assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
/// assert_eq!(soft_threshold(-1.0, 2.0), 0.0);
/// ```
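///
/// # Math note
///
/// `S(z, γ)` is the proximal operator of `γ·|·|`: it solves the scalar
/// problem `argmin_b ½(b − z)² + γ|b|`, which is exactly the subproblem each
/// coordinate descent update in this module minimizes.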
#[inline]
pub fn soft_threshold(z: f64, gamma: f64) -> f64 {
    if gamma < 0.0 {
        panic!("Soft threshold gamma must be non-negative");
    }
    if z > gamma {
        z - gamma
    } else if z < -gamma {
        z + gamma
    } else {
        0.0
    }
}

/// Options for elastic net fitting.
///
/// Configuration options for elastic net regression, which combines L1 and L2 penalties.
///
/// # Fields
///
/// - `lambda` - Regularization strength (≥ 0, higher = more regularization)
/// - `alpha` - Mixing parameter (0 = Ridge, 1 = Lasso, 0.5 = equal mix)
/// - `intercept` - Whether to include an intercept term
/// - `standardize` - Whether to standardize predictors to unit variance
/// - `max_iter` - Maximum coordinate descent iterations
/// - `tol` - Convergence tolerance on coefficient changes
/// - `penalty_factor` - Optional per-feature penalty multipliers
/// - `warm_start` - Optional initial coefficient values for warm starts
/// - `weights` - Optional observation weights
/// - `coefficient_bounds` - Optional (lower, upper) bounds for each coefficient
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::ElasticNetOptions;
/// let options = ElasticNetOptions {
///     lambda: 0.1,
///     alpha: 0.5,  // Equal mix of L1 and L2
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// ```
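///
/// As a sketch, non-negative least squares can be expressed through
/// `coefficient_bounds` (assuming two predictors here):
///
/// ```
/// # use linreg_core::regularized::elastic_net::ElasticNetOptions;
/// let nnls = ElasticNetOptions {
///     lambda: 0.0,
///     coefficient_bounds: Some(vec![(0.0, f64::INFINITY); 2]),
///     ..Default::default()
/// };
/// ```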
#[derive(Clone, Debug)]
pub struct ElasticNetOptions {
    /// Regularization strength (lambda >= 0)
    pub lambda: f64,
    /// Elastic net mixing parameter (0 <= alpha <= 1).
    /// alpha=1 is Lasso, alpha=0 is Ridge.
    pub alpha: f64,
    /// Whether to include an intercept term
    pub intercept: bool,
    /// Whether to standardize predictors
    pub standardize: bool,
    /// Maximum coordinate descent iterations
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes
    pub tol: f64,
    /// Per-feature penalty factors (optional).
    /// If None, all features have penalty factor 1.0.
    pub penalty_factor: Option<Vec<f64>>,
    /// Initial coefficients for warm start (optional).
    /// If provided, optimization starts from these values instead of zero.
    /// Used for efficient pathwise coordinate descent.
    pub warm_start: Option<Vec<f64>>,
    /// Observation weights (optional).
    /// If provided, must have length equal to the number of observations.
    /// Weights are normalized to sum to 1 internally.
    pub weights: Option<Vec<f64>>,
    /// Coefficient bounds: (lower, upper) for each predictor.
    /// If None, uses (-inf, +inf) for all coefficients (no bounds).
    ///
    /// The bounds vector length must equal the number of predictors (excluding intercept).
    /// For each predictor, the coefficient will be clamped to [lower, upper] after
    /// each coordinate descent update.
    ///
    /// # Examples
    /// * Non-negative least squares: `Some(vec![(0.0, f64::INFINITY); p])`
    /// * Upper bound only: `Some(vec![(-f64::INFINITY, 10.0); p])`
    /// * Both bounds: `Some(vec![(-5.0, 5.0); p])`
    ///
    /// # Notes
    /// * Bounds are applied to coefficients on the ORIGINAL scale, not the standardized scale
    /// * The intercept is never bounded
    /// * Each pair must satisfy `lower <= upper`
    pub coefficient_bounds: Option<Vec<(f64, f64)>>,
}

impl Default for ElasticNetOptions {
    fn default() -> Self {
        ElasticNetOptions {
            lambda: 1.0,
            alpha: 1.0, // Lasso default
            intercept: true,
            standardize: true,
            max_iter: 100000,
            tol: 1e-7,
            penalty_factor: None,
            warm_start: None,
            weights: None,
            coefficient_bounds: None,
        }
    }
}

/// Result of an elastic net fit.
///
/// Contains the fitted model coefficients, convergence information, and diagnostic metrics.
///
/// # Fields
///
/// - `lambda` - The regularization strength used
/// - `alpha` - The elastic net mixing parameter (0 = Ridge, 1 = Lasso)
/// - `intercept` - Intercept coefficient (never penalized)
/// - `coefficients` - Slope coefficients (may be sparse for high alpha)
/// - `fitted_values` - Predicted values on training data
/// - `residuals` - Residuals (y - fitted_values)
/// - `n_nonzero` - Number of non-zero coefficients (excluding intercept)
/// - `iterations` - Number of coordinate descent iterations performed
/// - `converged` - Whether the algorithm converged
/// - `r_squared` - Coefficient of determination
/// - `adj_r_squared` - Adjusted R²
/// - `mse` - Mean squared error
/// - `rmse` - Root mean squared error
/// - `mae` - Mean absolute error
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
/// # use linreg_core::linalg::Matrix;
/// # let y = vec![2.0, 4.0, 6.0, 8.0];
/// # let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let options = ElasticNetOptions { lambda: 0.1, alpha: 0.5, intercept: true, standardize: true, ..Default::default() };
/// let fit = elastic_net_fit(&x, &y, &options)?;
///
/// // Access fit results
/// println!("Lambda: {}, Alpha: {}", fit.lambda, fit.alpha);
/// println!("Non-zero coefficients: {}", fit.n_nonzero);
/// println!("Converged: {}", fit.converged);
/// println!("R²: {}", fit.r_squared);
/// # Ok::<(), linreg_core::Error>(())
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "wasm", derive(Serialize))]
pub struct ElasticNetFit {
    pub lambda: f64,
    pub alpha: f64,
    pub intercept: f64,
    pub coefficients: Vec<f64>,
    pub fitted_values: Vec<f64>,
    pub residuals: Vec<f64>,
    pub n_nonzero: usize,
    pub iterations: usize,
    pub converged: bool,
    pub r_squared: f64,
    pub adj_r_squared: f64,
    pub mse: f64,
    pub rmse: f64,
    pub mae: f64,
}

use crate::regularized::path::{make_lambda_path, LambdaPathOptions};

/// Fits an elastic net regularization path.
///
/// This is the most efficient way to fit models for multiple lambda values.
/// It performs data standardization once and uses warm starts to speed up
/// convergence along the path.
///
/// # Arguments
///
/// * `x` - Design matrix
/// * `y` - Response vector
/// * `path_options` - Options for generating the lambda path
/// * `fit_options` - Options for the elastic net fit (alpha, tol, etc.)
///
/// # Returns
///
/// A vector of `ElasticNetFit` structs, one for each lambda in the path.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::{elastic_net_path, ElasticNetOptions};
/// # use linreg_core::regularized::path::LambdaPathOptions;
/// # use linreg_core::linalg::Matrix;
/// let y = vec![2.0, 4.0, 6.0, 8.0];
/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
///
/// let path_options = LambdaPathOptions {
///     nlambda: 10,
///     ..Default::default()
/// };
/// let fit_options = ElasticNetOptions {
///     alpha: 0.5,
///     ..Default::default()
/// };
///
/// let path = elastic_net_path(&x, &y, &path_options, &fit_options)?;
/// assert_eq!(path.len(), 10); // One fit per lambda
///
/// // First model has strongest regularization (fewest non-zero coefficients)
/// println!("Non-zero at lambda_max: {}", path[0].n_nonzero);
/// // Last model has weakest regularization (most non-zero coefficients)
/// println!("Non-zero at lambda_min: {}", path.last().unwrap().n_nonzero);
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn elastic_net_path(
    x: &Matrix,
    y: &[f64],
    path_options: &LambdaPathOptions,
    fit_options: &ElasticNetOptions,
) -> Result<Vec<ElasticNetFit>> {
    let n = x.rows;
    let p = x.cols;

    if y.len() != n {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match number of rows in X ({})",
            y.len(), n
        )));
    }

    // 1. Standardize X and y ONCE
    let standardization_options = StandardizeOptions {
        intercept: fit_options.intercept,
        standardize_x: fit_options.standardize,
        standardize_y: fit_options.intercept,
        weights: fit_options.weights.clone(),
    };

    let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);

    // 2. Generate the lambda path. `LambdaPathOptions` only controls how the
    // path is generated; this entry point does not accept a user-supplied
    // lambda sequence. Penalty factors are passed through so they are
    // reflected in lambda_max.
    let intercept_col = if fit_options.intercept { Some(0) } else { None };
    let lambdas = make_lambda_path(
        &x_standardized,
        &y_standardized, // y_standardized is centered if intercept=true
        path_options,
        fit_options.penalty_factor.as_deref(),
        intercept_col,
    );

    // 3. Loop over lambdas with warm starts
    let mut fits = Vec::with_capacity(lambdas.len());
    let mut coefficients_standardized = vec![0.0; p]; // Initialize at 0

    // Determine unpenalized columns
    let first_penalized_column_index = if fit_options.intercept { 1 } else { 0 };

    // Scale factor for converting internal lambdas (standardized data) to
    // public, user-facing lambdas (original data scale):
    //   public lambda = internal lambda * y_scale
    // make_lambda_path returns internal lambdas, which the solver uses
    // directly; they are rescaled only for reporting.
    let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
    let lambda_conversion_factor = if y_scale_factor > 1e-12 {
        y_scale_factor
    } else {
        1.0
    };
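
    // Illustrative (hypothetical numbers): an internal lambda of 0.25, fitted
    // on y standardized by sd(y) = 2.0, is reported as a public lambda of 0.5.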

    for &lambda_standardized_value in &lambdas {
        // Path lambdas are already on the internal (standardized) scale and
        // feed coordinate descent without further scaling.
        let lambda_standardized = lambda_standardized_value;

        // Transform coefficient bounds to the standardized scale:
        //   bound_standardized = bound_orig * x_scale / y_scale
        let bounds_standardized: Option<Vec<(f64, f64)>> = fit_options.coefficient_bounds.as_ref().map(|bounds| {
            let y_scale = standardization_info.y_scale.unwrap_or(1.0);
            bounds.iter().enumerate().map(|(j, &(lower, upper))| {
                // Predictor j sits at column j+1 of the standardized matrix
                // when an intercept column occupies index 0, and at column j
                // otherwise.
                let std_idx = if fit_options.intercept { j + 1 } else { j };
                let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
                    standardization_info.x_scale[std_idx]
                } else {
                    1.0
                };
                let scale_factor = x_scale_predictor_j / y_scale;
                (lower * scale_factor, upper * scale_factor)
            }).collect()
        });

        let (iterations, converged) = coordinate_descent(
            &x_standardized,
            &y_standardized,
            &mut coefficients_standardized,
            lambda_standardized,
            fit_options.alpha,
            first_penalized_column_index,
            fit_options.max_iter,
            fit_options.tol,
            fit_options.penalty_factor.as_deref(),
            bounds_standardized.as_deref(),
            &standardization_info.column_squared_norms,
        )?;

        // Unstandardize coefficients for output
        let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);

        // Count non-zeros
        let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();

        // Fitted values & residuals
        let fitted = predict(x, intercept, &beta_orig);
        let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();

        // Statistics
        let y_mean = y.iter().sum::<f64>() / n as f64;
        let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
        let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
        let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;

        let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
        let eff_df = 1.0 + n_nonzero as f64;
        let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
            1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
        } else {
            r_squared
        };
        let mse = ss_res / (n as f64 - eff_df).max(1.0);

        // Convert the internal lambda to the public (user-facing) lambda for
        // reporting: public = internal * y_scale (matching R's glmnet).
        let lambda_original_scale = lambda_standardized_value * lambda_conversion_factor;

        fits.push(ElasticNetFit {
            lambda: lambda_original_scale,
            alpha: fit_options.alpha,
            intercept,
            coefficients: beta_orig,
            fitted_values: fitted,
            residuals,
            n_nonzero,
            iterations,
            converged,
            r_squared,
            adj_r_squared,
            mse,
            rmse: mse.sqrt(),
            mae,
        });
    }

    Ok(fits)
}

/// Fits elastic net regression for a single (lambda, alpha) pair.
///
/// Elastic net combines L1 (Lasso) and L2 (Ridge) penalties:
/// - `alpha = 1.0` is pure Lasso (L1 only)
/// - `alpha = 0.0` is pure Ridge (L2 only)
/// - `alpha = 0.5` is an equal mix
///
/// # Arguments
///
/// * `x` - Design matrix (n rows × p columns including intercept)
/// * `y` - Response variable (n observations)
/// * `options` - Configuration options for elastic net regression
///
/// # Returns
///
/// An `ElasticNetFit` containing coefficients, convergence info, and metrics.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
/// # use linreg_core::linalg::Matrix;
/// let y = vec![2.0, 4.0, 6.0, 8.0];
/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
///
/// // Elastic net with 50% L1, 50% L2
/// let options = ElasticNetOptions {
///     lambda: 0.1,
///     alpha: 0.5,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
///
/// let fit = elastic_net_fit(&x, &y, &options)?;
/// assert!(fit.converged);
/// println!("R²: {}", fit.r_squared);
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn elastic_net_fit(x: &Matrix, y: &[f64], options: &ElasticNetOptions) -> Result<ElasticNetFit> {
    if options.lambda < 0.0 {
        return Err(Error::InvalidInput("Lambda must be non-negative".into()));
    }
    if options.alpha < 0.0 || options.alpha > 1.0 {
        return Err(Error::InvalidInput("Alpha must be between 0 and 1".into()));
    }

    let n = x.rows;
    let p = x.cols;

    if y.len() != n {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match number of rows in X ({})",
            y.len(),
            n
        )));
    }

    // Validate coefficient bounds
    let n_predictors = if options.intercept { p - 1 } else { p };
    if let Some(ref bounds) = options.coefficient_bounds {
        if bounds.len() != n_predictors {
            return Err(Error::InvalidInput(format!(
                "Coefficient bounds length ({}) must match number of predictors ({})",
                bounds.len(), n_predictors
            )));
        }
        for (i, &(lower, upper)) in bounds.iter().enumerate() {
            if lower > upper {
                return Err(Error::InvalidInput(format!(
                    "Coefficient bounds for predictor {}: lower ({}) must be <= upper ({})",
                    i, lower, upper
                )));
            }
            // Note: We allow (-inf, +inf) as it represents "no bounds" for that
            // predictor, which lets callers mix bounded and unbounded predictors.
        }
    }

    // Standardize X and y
    // glmnet convention: y is always centered/scaled if intercept is present
    let standardization_options = StandardizeOptions {
        intercept: options.intercept,
        standardize_x: options.standardize,
        standardize_y: options.intercept,
        weights: options.weights.clone(),
    };

    let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);

    // Adjust lambda for scaling. Coordinate descent runs on standardized data
    // and expects an internal-scale lambda. For single-lambda fits the user
    // provides a public lambda (as R reports), which is converted here:
    //   lambda_standardized = lambda_public / y_scale
    let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
    let lambda_standardized = if y_scale_factor > 1e-12 {
        options.lambda / y_scale_factor
    } else {
        options.lambda
    };
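
    // Illustrative (hypothetical numbers): with sd(y) = 2.0, a user-facing
    // lambda of 0.5 runs through the solver as lambda_standardized = 0.25.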

    // Initial coefficients (all zeros)
    let mut coefficients_standardized = vec![0.0; p];

    // Column-0 convention: when `options.intercept` is true, `x` is expected
    // to carry a manual intercept column of 1s at index 0, which is never
    // penalized. `standardize_xy` standardizes all columns, and the true
    // intercept is recovered from the column means during unstandardization.
    // This mirrors the existing `lasso.rs` logic and is preserved for
    // compatibility with existing tests.
    let first_penalized_column_index = if options.intercept { 1 } else { 0 };

    // Warm start initialization. `warm` holds slope coefficients on the
    // ORIGINAL scale; transform them to the standardized scale used
    // internally: coefficients_standardized = beta_orig * x_scale / y_scale
    if let Some(warm) = &options.warm_start {
        let y_scale = standardization_info.y_scale.unwrap_or(1.0);

        if first_penalized_column_index == 1 {
            // Intercept column at index 0: the warm-start vector corresponds
            // to the slope columns 1..p. coefficients_standardized[0] stays
            // 0.0 (the intercept of centered data is 0).
            if warm.len() == p - 1 {
                for j in 1..p {
                    coefficients_standardized[j] = warm[j - 1] * standardization_info.x_scale[j] / y_scale;
                }
            }
            // A mismatched length indicates a caller error; the warm start is
            // silently ignored rather than risking a panic.
        } else {
            // No intercept column: the warm start maps one-to-one onto columns.
            if warm.len() == p {
                for j in 0..p {
                    coefficients_standardized[j] = warm[j] * standardization_info.x_scale[j] / y_scale;
                }
            }
        }
    }

    // Transform coefficient bounds to the standardized scale:
    //   bound_standardized = bound_orig * x_scale / y_scale
    let bounds_standardized: Option<Vec<(f64, f64)>> = options.coefficient_bounds.as_ref().map(|bounds| {
        let y_scale = standardization_info.y_scale.unwrap_or(1.0);
        bounds.iter().enumerate().map(|(j, &(lower, upper))| {
            // Predictor j sits at column j+1 of the standardized matrix when
            // an intercept column occupies index 0, and at column j otherwise.
            let std_idx = if options.intercept { j + 1 } else { j };
            let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
                standardization_info.x_scale[std_idx]
            } else {
                1.0
            };
            let scale_factor = x_scale_predictor_j / y_scale;
            (lower * scale_factor, upper * scale_factor)
        }).collect()
    });
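
    // Illustrative (hypothetical numbers): a bound of (0.0, 10.0) on a
    // predictor with x_scale = 2.0 and y_scale = 4.0 becomes (0.0, 5.0) on
    // the standardized scale.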

    let (iterations, converged) = coordinate_descent(
        &x_standardized,
        &y_standardized,
        &mut coefficients_standardized,
        lambda_standardized,
        options.alpha,
        first_penalized_column_index,
        options.max_iter,
        options.tol,
        options.penalty_factor.as_deref(),
        bounds_standardized.as_deref(),
        &standardization_info.column_squared_norms,
    )?;

    // Unstandardize
    let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);

    // Count non-zeros (excluding the intercept). `beta_orig` contains the
    // slopes on the original scale; `unstandardize_coefficients` handles the
    // column mapping, including the manual intercept column when present.
    let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();

    // Fitted values
    let fitted = predict(x, intercept, &beta_orig);
    let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();

    // Statistics
    let y_mean = y.iter().sum::<f64>() / n as f64;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;

    let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };

    // Effective DF approximation for Elastic Net:
    // df ≈ n_nonzero for Lasso; df ≈ trace(S) for Ridge.
    // We use the naive approximation n_nonzero here.
    let eff_df = 1.0 + n_nonzero as f64;
    let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
    } else {
        r_squared
    };

    let mse = ss_res / (n as f64 - eff_df).max(1.0);

    Ok(ElasticNetFit {
        lambda: options.lambda,
        alpha: options.alpha,
        intercept,
        coefficients: beta_orig,
        fitted_values: fitted,
        residuals,
        n_nonzero,
        iterations,
        converged,
        r_squared,
        adj_r_squared,
        mse,
        rmse: mse.sqrt(),
        mae,
    })
}

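/// Cyclical coordinate descent with soft-thresholding and an active-set
/// strategy.
///
/// Alternates full passes over all penalized columns with inner sweeps over
/// the active set (the currently non-zero coefficients). Convergence is
/// declared only when a full pass changes no coefficient by more than `tol`.
///
/// Returns `(iterations, converged)`.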
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn coordinate_descent(
    x: &Matrix,
    y: &[f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    first_penalized_column_index: usize,
    max_iter: usize,
    tol: f64,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64], // Column squared norms (for coordinate descent update)
) -> Result<(usize, bool)> {
    let n = x.rows;
    let p = x.cols;

    // Residuals r = y - Xβ
    // Initialize with all betas zero -> residuals = y
    // If y contains infinity/NaN, residuals will too
    let mut residuals = y.to_vec();

    // Check for non-finite residuals initially - if present, we can't optimize
    if residuals.iter().any(|r| !r.is_finite()) {
        return Ok((0, false));
    }

    // Handle non-zero initial betas (warm starts)
    for j in 0..p {
        if beta[j] != 0.0 {
            for i in 0..n {
                residuals[i] -= x.get(i, j) * beta[j];
            }
        }
    }

    // Active set: indices of non-zero coefficients
    let mut active_set = vec![false; p];

    let mut converged = false;
    let mut iter = 0;

    while iter < max_iter {
        let mut maximum_coefficient_change = 0.0;

        // --- Full Pass ---
        for j in first_penalized_column_index..p {
            if update_feature(j, x, &mut residuals, beta, lambda, alpha, first_penalized_column_index, penalty_factor, bounds, column_squared_norms, &mut maximum_coefficient_change) {
                active_set[j] = true;
            }
        }
        iter += 1;

        if maximum_coefficient_change < tol {
            converged = true;
            break;
        }

        // --- Active Set Loop ---
        loop {
            if iter >= max_iter { break; }

            let mut active_set_coefficient_change = 0.0;
            let mut active_count = 0;

            for j in first_penalized_column_index..p {
                if active_set[j] {
                    update_feature(j, x, &mut residuals, beta, lambda, alpha, first_penalized_column_index, penalty_factor, bounds, column_squared_norms, &mut active_set_coefficient_change);
                    active_count += 1;

                    if beta[j] == 0.0 {
                        active_set[j] = false;
                    }
                }
            }

            iter += 1;

            if active_set_coefficient_change < tol {
                break;
            }

            if active_count == 0 {
                break;
            }
        }
    }

    Ok((iter, converged))
}

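/// Performs a single coordinate descent update for feature `j`.
///
/// Computes the partial-residual correlation, applies soft-thresholding with
/// the elastic net denominator, clamps to any coefficient bounds, and keeps
/// the residual vector in sync with the new coefficient. Returns `true` if
/// the coefficient changed.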
#[inline]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn update_feature(
    j: usize,
    x: &Matrix,
    residuals: &mut [f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    first_penalized_column_index: usize,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64], // Column squared norms (for coordinate descent update)
    maximum_coefficient_change: &mut f64,
) -> bool {
    // Features with an infinite penalty factor are excluded entirely.
    let penalty_factor_value = penalty_factor.and_then(|v| v.get(j)).copied().unwrap_or(1.0);
    if penalty_factor_value == f64::INFINITY {
        beta[j] = 0.0;
        return false;
    }

    let n = x.rows;
    let coefficient_previous = beta[j];

    // Partial residual correlation (rho).
    // residuals currently = y - Sum_k(Xk * beta_k), so the partial residual is
    //   r_partial = y - Sum_{k!=j}(Xk * beta_k) = residuals + Xj * beta_j
    // and
    //   rho = Xj^T * r_partial = Xj^T * residuals + (Xj^T * Xj) * beta_j,
    // where Xj^T * Xj = column_squared_norms[j] (the squared norm of column j
    // after standardization).
    let mut partial_correlation_unscaled = 0.0;
    for i in 0..n {
        partial_correlation_unscaled += x.get(i, j) * residuals[i];
    }
    // Use column_squared_norms[j] instead of assuming 1.0
    let rho = partial_correlation_unscaled + column_squared_norms[j] * coefficient_previous;

    // Soft-threshold the numerator: S(rho, lambda * alpha * penalty_factor_value)
    let threshold = lambda * alpha * penalty_factor_value;
    let soft_threshold_result = soft_threshold(rho, threshold);

    // Elastic net denominator (matches glmnet's formula):
    //   column_squared_norms[j] + lambda * (1 - alpha) * penalty_factor_value
    let denominator_with_ridge_penalty = column_squared_norms[j] + lambda * (1.0 - alpha) * penalty_factor_value;

    let mut coefficient_updated = soft_threshold_result / denominator_with_ridge_penalty;

    // Clamp to coefficient bounds if provided. `bounds` is indexed by
    // predictor (excluding any manual intercept column), so shift `j` by the
    // first penalized column index.
    if let Some(bounds) = bounds {
        let bounds_idx = j - first_penalized_column_index;
        if let Some((lower, upper)) = bounds.get(bounds_idx) {
            coefficient_updated = coefficient_updated.max(*lower).min(*upper);
        }
    }

    // Update residuals if beta changed
    if coefficient_updated != coefficient_previous {
        let coefficient_change = coefficient_updated - coefficient_previous;
        for i in 0..n {
            // residuals_new = residuals_old - x_j * coefficient_change
            residuals[i] -= x.get(i, j) * coefficient_change;
        }
        beta[j] = coefficient_updated;
        *maximum_coefficient_change = maximum_coefficient_change.max(coefficient_change.abs());
        true // changed
    } else {
        false // no change
    }
}