// linreg_core/regularized/elastic_net.rs

1//! Elastic Net regression (L1 + L2 regularized linear regression).
2//!
3//! This module provides a generalized elastic net implementation using cyclical
4//! coordinate descent with soft-thresholding and active set convergence strategies.
5//! It serves as the core engine for both Lasso (`alpha=1.0`) and Ridge (`alpha=0.0`).
6//!
7//! # Objective Function
8//!
9//! Minimizes over `(β₀, β)`:
10//!
11//! ```text
12//! (1/(2n)) * ||y - β₀ - Xβ||² + λ * [ (1-α)||β||₂²/2 + α||β||₁ ]
13//! ```
14//!
15//! Note on scaling: The internal implementation works with standardized data (unit norm columns).
16//! The lambda parameter is adjusted internally to match the scale expected by the formulation above.
17
18use crate::core::{aic, bic, log_likelihood};
19use crate::error::{Error, Result};
20use crate::linalg::Matrix;
21use crate::regularized::preprocess::{
22    predict, standardize_xy, unstandardize_coefficients, StandardizeOptions,
23};
24use crate::serialization::types::ModelType;
25use crate::impl_serialization;
26use serde::{Deserialize, Serialize};
27
/// Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0)
///
/// The workhorse of the L1 penalty in Lasso and Elastic Net coordinate
/// descent: it shrinks `z` toward zero by `gamma` and truncates anything
/// that would cross zero, which is what produces sparse solutions.
///
/// # Arguments
///
/// * `z` - Input value to be thresholded
/// * `gamma` - Threshold value (must be non-negative)
///
/// # Returns
///
/// - `z - gamma` when `z > gamma`
/// - `z + gamma` when `z < -gamma`
/// - `0.0` when `|z| <= gamma`
///
/// # Panics
///
/// Panics if `gamma` is negative.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::soft_threshold;
/// // Values above the threshold are reduced toward zero
/// assert_eq!(soft_threshold(5.0, 2.0), 3.0);
///
/// // Values within the threshold collapse to zero
/// assert_eq!(soft_threshold(1.0, 2.0), 0.0);
///
/// // Negative values behave symmetrically
/// assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
/// assert_eq!(soft_threshold(-1.0, 2.0), 0.0);
/// ```
#[inline]
pub fn soft_threshold(z: f64, gamma: f64) -> f64 {
    if gamma < 0.0 {
        panic!("Soft threshold gamma must be non-negative");
    }
    // Shrink toward zero; anything whose magnitude is exhausted is clamped to 0.
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        _ => 0.0,
    }
}
75
/// Options for elastic net fitting.
///
/// Configures the combined L1/L2 penalized regression solver.
///
/// # Fields
///
/// - `lambda` - Regularization strength (≥ 0, higher = more regularization)
/// - `alpha` - Mixing parameter (0 = Ridge, 1 = Lasso, 0.5 = equal mix)
/// - `intercept` - Whether to include an intercept term
/// - `standardize` - Whether to standardize predictors to unit variance
/// - `max_iter` - Maximum coordinate descent iterations
/// - `tol` - Convergence tolerance on coefficient changes
/// - `penalty_factor` - Optional per-feature penalty multipliers
/// - `warm_start` - Optional initial coefficient values for warm starts
/// - `weights` - Optional observation weights
/// - `coefficient_bounds` - Optional (lower, upper) bounds for each coefficient
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::ElasticNetOptions;
/// let options = ElasticNetOptions {
///     lambda: 0.1,
///     alpha: 0.5,  // Equal mix of L1 and L2
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// ```
#[derive(Clone, Debug)]
pub struct ElasticNetOptions {
    /// Regularization strength (lambda >= 0)
    pub lambda: f64,
    /// Elastic net mixing parameter (0 <= alpha <= 1).
    /// alpha=1 is Lasso, alpha=0 is Ridge.
    pub alpha: f64,
    /// Whether to include an intercept term
    pub intercept: bool,
    /// Whether to standardize predictors
    pub standardize: bool,
    /// Maximum coordinate descent iterations
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes
    pub tol: f64,
    /// Per-feature penalty factors (optional).
    /// If None, all features have penalty factor 1.0.
    pub penalty_factor: Option<Vec<f64>>,
    /// Initial coefficients for warm start (optional).
    /// If provided, optimization starts from these values instead of zero.
    /// Used for efficient pathwise coordinate descent.
    pub warm_start: Option<Vec<f64>>,
    /// Observation weights (optional).
    /// If provided, must have length equal to the number of observations.
    /// Weights are normalized to sum to 1 internally.
    pub weights: Option<Vec<f64>>,
    /// Coefficient bounds: (lower, upper) for each predictor.
    /// If None, uses (-inf, +inf) for all coefficients (no bounds).
    ///
    /// The bounds vector length must equal the number of predictors (excluding intercept).
    /// For each predictor, the coefficient will be clamped to [lower, upper] after
    /// each coordinate descent update.
    ///
    /// # Examples
    /// * Non-negative least squares: `Some(vec![(0.0, f64::INFINITY); p])`
    /// * Upper bound only: `Some(vec![(-f64::INFINITY, 10.0); p])`
    /// * Both bounds: `Some(vec![(-5.0, 5.0); p])`
    ///
    /// # Notes
    /// * Bounds are applied to coefficients on the ORIGINAL scale, not standardized scale
    /// * The intercept is never bounded
    /// * Each pair must satisfy `lower <= upper`
    pub coefficient_bounds: Option<Vec<(f64, f64)>>,
}

impl Default for ElasticNetOptions {
    /// Lasso-flavored defaults: `alpha = 1.0`, `lambda = 1.0`, intercept and
    /// standardization on, with a generous iteration budget and tight tolerance.
    fn default() -> Self {
        Self {
            lambda: 1.0,
            // Pure L1 penalty by default (Lasso).
            alpha: 1.0,
            intercept: true,
            standardize: true,
            max_iter: 100000,
            tol: 1e-7,
            penalty_factor: None,
            warm_start: None,
            weights: None,
            coefficient_bounds: None,
        }
    }
}
166
/// Result of an elastic net fit.
///
/// Contains the fitted model coefficients, convergence information, and
/// diagnostic metrics for a single (lambda, alpha) pair.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
/// # use linreg_core::linalg::Matrix;
/// # let y = vec![2.0, 4.0, 6.0, 8.0];
/// # let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let options = ElasticNetOptions { lambda: 0.1, alpha: 0.5, intercept: true, standardize: true, ..Default::default() };
/// let fit = elastic_net_fit(&x, &y, &options).unwrap();
///
/// // Access fit results
/// println!("Lambda: {}, Alpha: {}", fit.lambda, fit.alpha);
/// println!("Non-zero coefficients: {}", fit.n_nonzero);
/// println!("Converged: {}", fit.converged);
/// println!("R²: {}", fit.r_squared);
/// println!("AIC: {}", fit.aic);
/// # Ok::<(), linreg_core::Error>(())
/// ```
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ElasticNetFit {
    /// Regularization strength used for this fit (original data scale).
    pub lambda: f64,
    /// Elastic net mixing parameter (0 = Ridge, 1 = Lasso).
    pub alpha: f64,
    /// Intercept coefficient (never penalized).
    pub intercept: f64,
    /// Slope coefficients on the original data scale; sparse for high alpha.
    pub coefficients: Vec<f64>,
    /// Predicted values on the training data.
    pub fitted_values: Vec<f64>,
    /// Residuals, `y - fitted_values`.
    pub residuals: Vec<f64>,
    /// Number of non-zero slope coefficients (excludes intercept).
    pub n_nonzero: usize,
    /// Coordinate descent iterations actually performed.
    pub iterations: usize,
    /// Whether the solver converged within the iteration budget.
    pub converged: bool,
    /// Coefficient of determination (R²).
    pub r_squared: f64,
    /// Adjusted R², using df ≈ 1 + n_nonzero.
    pub adj_r_squared: f64,
    /// Mean squared error (residual SS over effective residual df).
    pub mse: f64,
    /// Root mean squared error, `mse.sqrt()`.
    pub rmse: f64,
    /// Mean absolute error of the residuals.
    pub mae: f64,
    /// Log-likelihood of the model (for model comparison).
    pub log_likelihood: f64,
    /// Akaike Information Criterion (lower = better).
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better).
    pub bic: f64,
}
229
230use crate::regularized::path::{make_lambda_path, LambdaPathOptions};
231
232/// Fits an elastic net regularization path.
233///
234/// This is the most efficient way to fit models for multiple lambda values.
235/// It performs data standardization once and uses warm starts to speed up
236/// convergence along the path.
237///
238/// # Arguments
239///
240/// * `x` - Design matrix
241/// * `y` - Response vector
242/// * `path_options` - Options for generating the lambda path
243/// * `fit_options` - Options for the elastic net fit (alpha, tol, etc.)
244///
245/// # Returns
246///
247/// A vector of `ElasticNetFit` structs, one for each lambda in the path.
248///
249/// # Example
250///
251/// ```
252/// # use linreg_core::regularized::elastic_net::{elastic_net_path, ElasticNetOptions};
253/// # use linreg_core::regularized::path::LambdaPathOptions;
254/// # use linreg_core::linalg::Matrix;
255/// let y = vec![2.0, 4.0, 6.0, 8.0];
256/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
257///
258/// let path_options = LambdaPathOptions {
259///     nlambda: 10,
260///     ..Default::default()
261/// };
262/// let fit_options = ElasticNetOptions {
263///     alpha: 0.5,
264///     ..Default::default()
265/// };
266///
267/// let path = elastic_net_path(&x, &y, &path_options, &fit_options).unwrap();
268/// assert_eq!(path.len(), 10); // One fit per lambda
269///
270/// // First model has strongest regularization (fewest non-zero coefficients)
271/// println!("Non-zero at lambda_max: {}", path[0].n_nonzero);
272/// // Last model has weakest regularization (most non-zero coefficients)
273/// println!("Non-zero at lambda_min: {}", path.last().unwrap().n_nonzero);
274/// # Ok::<(), linreg_core::Error>(())
275/// ```
276pub fn elastic_net_path(
277    x: &Matrix,
278    y: &[f64],
279    path_options: &LambdaPathOptions,
280    fit_options: &ElasticNetOptions,
281) -> Result<Vec<ElasticNetFit>> {
282    let n = x.rows;
283    let p = x.cols;
284
285    if y.len() != n {
286        return Err(Error::DimensionMismatch(format!(
287            "Length of y ({}) must match number of rows in X ({})",
288            y.len(), n
289        )));
290    }
291
292    // 1. Standardize X and y ONCE
293    let standardization_options = StandardizeOptions {
294        intercept: fit_options.intercept,
295        standardize_x: fit_options.standardize,
296        standardize_y: fit_options.intercept,
297        weights: fit_options.weights.clone(),
298    };
299
300    let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);
301
302    // 2. Generate lambda path
303    // If lambdas are not provided in options (which they aren't in LambdaPathOptions, 
304    // it just controls generation), we generate them.
305    // NOTE: If the user wants specific lambdas, they should probably use a different API
306    // or we could add `lambdas: Option<&[f64]>` to this function.
307    // For now, we strictly generate them.
308    
309    // We need to account for penalty factors in lambda generation if provided
310    let intercept_col = if fit_options.intercept { Some(0) } else { None };
311    let lambdas = make_lambda_path(
312        &x_standardized,
313        &y_standardized, // y_standardized is centered if intercept=true
314        path_options, 
315        fit_options.penalty_factor.as_deref(), 
316        intercept_col
317    );
318
319    // 3. Loop over lambdas with warm starts
320    let mut fits = Vec::with_capacity(lambdas.len());
321    let mut coefficients_standardized = vec![0.0; p]; // Initialize at 0
322
323    // Determine unpenalized columns
324    let first_penalized_column_index = if fit_options.intercept { 1 } else { 0 };
325
326    // Calculate scale factor for converting Internal lambdas to Public (user-facing) lambdas
327    // make_lambda_path returns Internal lambdas (for standardized data)
328    // We use these directly in the solver, but scale them for user reporting
329    let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
330    // Public lambda = Internal lambda * y_scale_factor
331    // This converts from standardized scale to original data scale
332    let lambda_conversion_factor = if y_scale_factor > 1e-12 {
333        y_scale_factor
334    } else {
335        1.0
336    };
337
338    for &lambda_standardized_value in &lambdas {
339        // The path generation returns lambdas on the internal scale (for standardized data),
340        // which are used directly in coordinate descent without additional scaling.
341        let lambda_standardized = lambda_standardized_value;
342
343        // Transform coefficient bounds to standardized scale
344        // Bounds on original scale need to be converted: coefficients_standardized = beta_orig * x_scale / y_scale
345        let bounds_standardized: Option<Vec<(f64, f64)>> = fit_options.coefficient_bounds.as_ref().map(|bounds| {
346            let y_scale = standardization_info.y_scale.unwrap_or(1.0);
347            bounds.iter().enumerate().map(|(j, &(lower, upper))| {
348                // For each predictor j in original scale, the corresponding column
349                // in the standardized matrix is at index j+1 (col 0 is intercept)
350                let std_idx = j + 1;
351                let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
352                    standardization_info.x_scale[std_idx]
353                } else {
354                    1.0
355                };
356                let scale_factor = x_scale_predictor_j / y_scale;
357                (lower * scale_factor, upper * scale_factor)
358            }).collect()
359        });
360
361        let (iterations, converged) = coordinate_descent(
362            &x_standardized,
363            &y_standardized,
364            &mut coefficients_standardized,
365            lambda_standardized,
366            fit_options.alpha,
367            first_penalized_column_index,
368            fit_options.max_iter,
369            fit_options.tol,
370            fit_options.penalty_factor.as_deref(),
371            bounds_standardized.as_deref(),
372            &standardization_info.column_squared_norms,
373        )?;
374
375        // Unstandardize coefficients for output
376        let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);
377
378        // Count non-zeros
379        let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();
380
381        // Fitted values & residuals
382        let fitted = predict(x, intercept, &beta_orig);
383        let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();
384
385        // Statistics
386        let y_mean = y.iter().sum::<f64>() / n as f64;
387        let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
388        let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
389        let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
390
391        let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
392        let eff_df = 1.0 + n_nonzero as f64;
393        let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
394            1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
395        } else {
396            r_squared
397        };
398        let mse = ss_res / (n as f64 - eff_df).max(1.0);
399
400        // Model selection criteria
401        let ll = log_likelihood(n, mse, ss_res);
402        let n_coef = beta_orig.len() + 1; // coefficients + intercept
403        let aic_val = aic(ll, n_coef);
404        let bic_val = bic(ll, n_coef, n);
405
406        // Convert Internal lambda to Public (user-facing) lambda for reporting
407        // Public = Internal * y_scale_var * n (to match R's glmnet reporting)
408        let lambda_original_scale = lambda_standardized_value * lambda_conversion_factor;
409
410        fits.push(ElasticNetFit {
411            lambda: lambda_original_scale,
412            alpha: fit_options.alpha,
413            intercept,
414            coefficients: beta_orig,
415            fitted_values: fitted,
416            residuals,
417            n_nonzero,
418            iterations,
419            converged,
420            r_squared,
421            adj_r_squared,
422            mse,
423            rmse: mse.sqrt(),
424            mae,
425            log_likelihood: ll,
426            aic: aic_val,
427            bic: bic_val,
428        });
429    }
430
431    Ok(fits)
432}
433
434/// Fits elastic net regression for a single (lambda, alpha) pair.
435///
436/// Elastic net combines L1 (Lasso) and L2 (Ridge) penalties:
437/// - `alpha = 1.0` is pure Lasso (L1 only)
438/// - `alpha = 0.0` is pure Ridge (L2 only)
439/// - `alpha = 0.5` is an equal mix
440///
441/// # Arguments
442///
443/// * `x` - Design matrix (n rows × p columns including intercept)
444/// * `y` - Response variable (n observations)
445/// * `options` - Configuration options for elastic net regression
446///
447/// # Returns
448///
449/// An `ElasticNetFit` containing coefficients, convergence info, and metrics.
450///
451/// # Example
452///
453/// ```
454/// # use linreg_core::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
455/// # use linreg_core::linalg::Matrix;
456/// let y = vec![2.0, 4.0, 6.0, 8.0];
457/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
458///
459/// // Elastic net with 50% L1, 50% L2
460/// let options = ElasticNetOptions {
461///     lambda: 0.1,
462///     alpha: 0.5,
463///     intercept: true,
464///     standardize: true,
465///     ..Default::default()
466/// };
467///
468/// let fit = elastic_net_fit(&x, &y, &options).unwrap();
469/// assert!(fit.converged);
470/// println!("R²: {}", fit.r_squared);
471/// # Ok::<(), linreg_core::Error>(())
472/// ```
473///
474/// # Errors
475///
476/// Returns `Error::InvalidInput` if:
477/// - `lambda` is negative
478/// - `alpha` is not in [0, 1]
479///
480/// Returns `Error::InsufficientData` if `x.rows() <= x.cols()`.
481/// Returns `Error::SingularMatrix` if the design matrix is singular.
482///
483/// # Panics
484///
485/// Panics if `x.cols()` is 0 (no predictors including intercept).
486pub fn elastic_net_fit(x: &Matrix, y: &[f64], options: &ElasticNetOptions) -> Result<ElasticNetFit> {
487    if options.lambda < 0.0 {
488        return Err(Error::InvalidInput("Lambda must be non-negative".into()));
489    }
490    if options.alpha < 0.0 || options.alpha > 1.0 {
491        return Err(Error::InvalidInput("Alpha must be between 0 and 1".into()));
492    }
493
494    let n = x.rows;
495    let p = x.cols;
496
497    if y.len() != n {
498        return Err(Error::DimensionMismatch(format!(
499            "Length of y ({}) must match number of rows in X ({})",
500            y.len(),
501            n
502        )));
503    }
504
505    // Validate coefficient bounds
506    let n_predictors = if options.intercept { p - 1 } else { p };
507    if let Some(ref bounds) = options.coefficient_bounds {
508        if bounds.len() != n_predictors {
509            return Err(Error::InvalidInput(format!(
510                "Coefficient bounds length ({}) must match number of predictors ({})",
511                bounds.len(), n_predictors
512            )));
513        }
514        for (i, &(lower, upper)) in bounds.iter().enumerate() {
515            if lower > upper {
516                return Err(Error::InvalidInput(format!(
517                    "Coefficient bounds for predictor {}: lower ({}) must be <= upper ({})",
518                    i, lower, upper
519                )));
520            }
521            // Note: We allow (-inf, +inf) as it represents "no bounds" for that predictor
522            // This is useful for having mixed bounded/unbounded predictors
523        }
524    }
525
526    // Standardize X and y
527    // glmnet convention: y is always centered/scaled if intercept is present
528    let standardization_options = StandardizeOptions {
529        intercept: options.intercept,
530        standardize_x: options.standardize,
531        standardize_y: options.intercept,
532        weights: options.weights.clone(),
533    };
534
535    let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);
536
537    // Adjust lambda for scaling
538    // The path generation returns internal lambdas (for standardized data),
539    // which are used directly in coordinate descent.
540    //
541    // For single-lambda fits, the user provides "public" lambda values
542    // (like R reports), which need to be converted to "internal" scale:
543    //   lambda_standardized_value = lambda_original_scale / y_scale
544    let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
545    let lambda_standardized = if y_scale_factor > 1e-12 {
546        options.lambda / y_scale_factor
547    } else {
548        options.lambda
549    };
550
551    // DEBUG: Print scaling info
552    // #[cfg(debug_assertions)]
553    // {
554    //     eprintln!("DEBUG elastic_net_fit: user_lambda = {}, y_scale = {}, lambda_standardized = {}",
555    //              options.lambda, y_scale_factor, lambda_standardized);
556    // }
557
558    // Initial coefficients (all zeros)
559    let mut coefficients_standardized = vec![0.0; p];
560
561    // Determine unpenalized columns (e.g. intercept column 0 if manually added,
562    // but standardize_xy handles the intercept externally usually.
563    // If intercept=true, standardize_xy centers data and we don't penalize an implicit intercept.
564    // Here we assume x contains PREDICTORS only if intercept is handled by standardization centering.
565    // However, the `Matrix` struct might include a column of 1s if the user passed it.
566    // `standardize_xy` treats all columns in X as predictors to be standardized.
567    // If options.intercept is true, we compute the intercept from the means later.
568    // We assume X passed here does NOT contain a manual intercept column of 1s unless
569    // the user explicitly wants to penalize it (which is weird) or turned off intercept in options.
570    // For now, we penalize all columns in X according to penalty_factors.
571
572    // Check if we assume X has an intercept column at 0 that we should skip?
573    // The previous ridge/lasso implementations had a `first_penalized_column_index` logic:
574    // `let first_penalized_column_index = if options.intercept { 1 } else { 0 };`
575    // This implies `x` might have a column of 1s.
576    // GLMNET convention usually takes x matrix of predictors only.
577    // `standardize_xy` calculates means for ALL columns.
578    // If column 0 is all 1s, std dev is 0, standardization might fail or set to 0.
579    // Let's stick to the previous `lasso.rs` logic: if intercept is requested, we ignore column 0?
580    // `lasso.rs`: "Determine which columns are penalized. first_penalized_column_index = if options.intercept { 1 } else { 0 }"
581    // This strongly suggests the input Matrix `x` is expected to have a column of 1s at index 0 if intercept=true.
582    // We will preserve this behavior for compatibility with existing tests.
583    // i.e. this is going to be hell to refactor and I'm idly typing my thoughts away...
584    // This is a naive implementation anyways and only one head of the hydra that is glmnet.
585    let first_penalized_column_index = if options.intercept { 1 } else { 0 };
586
587    // Warm start initialization
588    if let Some(warm) = &options.warm_start {
589        // warm contains slope coefficients on ORIGINAL scale
590        // We need to transform them to STANDARDIZED scale
591        // coefficients_standardized = beta_orig * x_scale / y_scale
592        let y_scale = standardization_info.y_scale.unwrap_or(1.0);
593
594        if first_penalized_column_index == 1 {
595            // Case 1: Intercept at col 0
596            // warm start vector should correspond to cols 1..p (slopes)
597            // coefficients_standardized[0] stays 0.0 (intercept of centered data is 0)
598            if warm.len() == p - 1 {
599                for j in 1..p {
600                    coefficients_standardized[j] = warm[j - 1] * standardization_info.x_scale[j] / y_scale;
601                }
602            } else {
603                // If dimensions don't match, ignore warm start or warn?
604                // For safety in this "todo" fix, we'll just ignore mismatched warm starts to avoid panics,
605                // but usually this indicates a caller error.
606                // Given I can't print warnings easily here, I'll ignore or maybe assume warm includes intercept?
607                // If warm has length p, maybe it includes intercept? But ElasticNetFit.coefficients excludes it.
608                // Let's stick to: warm start matches slopes.
609            }
610        } else {
611            // Case 2: No intercept column
612            if warm.len() == p {
613                for j in 0..p {
614                    coefficients_standardized[j] = warm[j] * standardization_info.x_scale[j] / y_scale;
615                }
616            }
617        }
618    }
619
620    // Transform coefficient bounds to standardized scale
621    // Bounds on original scale need to be converted: coefficients_standardized = beta_orig * x_scale / y_scale
622    let bounds_standardized: Option<Vec<(f64, f64)>> = options.coefficient_bounds.as_ref().map(|bounds| {
623        let y_scale = standardization_info.y_scale.unwrap_or(1.0);
624        bounds.iter().enumerate().map(|(j, &(lower, upper))| {
625            // For each predictor j in original scale, the corresponding column
626            // in the standardized matrix is at index j+1 (col 0 is intercept)
627            let std_idx = j + 1;
628            let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
629                standardization_info.x_scale[std_idx]
630            } else {
631                1.0
632            };
633            let scale_factor = x_scale_predictor_j / y_scale;
634            (lower * scale_factor, upper * scale_factor)
635        }).collect()
636    });
637
638    let (iterations, converged) = coordinate_descent(
639        &x_standardized,
640        &y_standardized,
641        &mut coefficients_standardized,
642        lambda_standardized,
643        options.alpha,
644        first_penalized_column_index,
645        options.max_iter,
646        options.tol,
647        options.penalty_factor.as_deref(),
648        bounds_standardized.as_deref(),
649        &standardization_info.column_squared_norms,
650    )?;
651
652    // Unstandardize
653    let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);
654
655    // Count nonzero (excluding intercept)
656    // beta_orig contains slopes. If first_penalized_column_index=1, coefficients_standardized[0] was 0.
657    // The coefficients returned should correspond to the columns of X (excluding the manual intercept if present?).
658    // `unstandardize_coefficients` handles the mapping.
659    let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();
660
661    // Fitted values
662    let fitted = predict(x, intercept, &beta_orig);
663    let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();
664
665    // Statistics
666    let y_mean = y.iter().sum::<f64>() / n as f64;
667    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
668    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
669    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
670
671    let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
672
673    // Effective DF approximation for Elastic Net
674    // df ≈ n_nonzero for Lasso
675    // df ≈ trace(S) for Ridge
676    // We use a naive approximation here: n_nonzero
677    let eff_df = 1.0 + n_nonzero as f64;
678    let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
679        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
680    } else {
681        r_squared
682    };
683
684    let mse = ss_res / (n as f64 - eff_df).max(1.0);
685
686    // Model selection criteria
687    let ss_res: f64 = residuals.iter().map(|&r| r * r).sum();
688    let ll = log_likelihood(n, mse, ss_res);
689    let n_coef = beta_orig.len() + 1; // coefficients + intercept
690    let aic_val = aic(ll, n_coef);
691    let bic_val = bic(ll, n_coef, n);
692
693    Ok(ElasticNetFit {
694        lambda: options.lambda,
695        alpha: options.alpha,
696        intercept,
697        coefficients: beta_orig,
698        fitted_values: fitted,
699        residuals,
700        n_nonzero,
701        iterations,
702        converged,
703        r_squared,
704        adj_r_squared,
705        mse,
706        rmse: mse.sqrt(),
707        mae,
708        log_likelihood: ll,
709        aic: aic_val,
710        bic: bic_val,
711    })
712}
713
/// Cyclical coordinate descent for the elastic net on standardized data.
///
/// Alternates one full pass over all penalized columns with cheaper passes
/// over the active set (columns whose coefficients are currently non-zero)
/// until the largest per-pass coefficient change drops below `tol` or
/// `max_iter` passes have been spent.
///
/// `beta` is updated in place and may carry a warm start; columns before
/// `first_penalized_column_index` (e.g. an intercept column) are never
/// updated here.
///
/// Returns `(iterations, converged)`. If `y` contains non-finite values the
/// function returns `(0, false)` immediately, since residual updates would
/// only propagate the NaN/inf.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn coordinate_descent(
    x: &Matrix,
    y: &[f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    first_penalized_column_index: usize,
    max_iter: usize,
    tol: f64,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64],  // Column squared norms (for coordinate descent update)
) -> Result<(usize, bool)> {
    let n = x.rows;
    let p = x.cols;

    // Residuals r = y - Xβ.
    // With all betas zero this is just y; non-zero warm-start betas are
    // subtracted out below.
    let mut residuals = y.to_vec();

    // Non-finite residuals make optimization meaningless — bail out as
    // "did not converge" rather than iterating on NaN/inf.
    if residuals.iter().any(|r| !r.is_finite()) {
        return Ok((0, false));
    }

    // Fold any warm-start coefficients into the residuals so the invariant
    // r = y - Xβ holds before the first pass.
    for j in 0..p {
        if beta[j] != 0.0 {
            for i in 0..n {
                residuals[i] -= x.get(i, j) * beta[j];
            }
        }
    }

    // Active set: flags for columns with currently non-zero coefficients.
    let mut active_set = vec![false; p];

    let mut converged = false;
    let mut iter = 0;

    while iter < max_iter {
        let mut maximum_coefficient_change = 0.0;

        // --- Full pass over every penalized column; newly non-zero
        // coefficients join the active set. ---
        for j in first_penalized_column_index..p {
            if update_feature(j, x, &mut residuals, beta, lambda, alpha, penalty_factor, bounds, column_squared_norms, &mut maximum_coefficient_change) {
                active_set[j] = true;
            }
        }
        iter += 1;

        // Converged only when a FULL pass moves nothing — active-set passes
        // alone cannot declare convergence.
        if maximum_coefficient_change < tol {
            converged = true;
            break;
        }

        // --- Active-set passes: iterate only the non-zero coefficients until
        // they stabilize, then return to a full pass. ---
        loop {
            if iter >= max_iter { break; }

            let mut active_set_coefficient_change = 0.0;
            let mut active_count = 0;

            for j in first_penalized_column_index..p {
                if active_set[j] {
                    update_feature(j, x, &mut residuals, beta, lambda, alpha, penalty_factor, bounds, column_squared_norms, &mut active_set_coefficient_change);
                    active_count += 1;

                    // Coefficients shrunk back to exactly zero drop out of
                    // the active set.
                    if beta[j] == 0.0 {
                       active_set[j] = false;
                    }
                }
            }

            iter += 1;

            if active_set_coefficient_change < tol {
                break;
            }

            // Empty active set: nothing left to refine in this inner loop.
            if active_count == 0 {
                break;
            }
        }
    }

    Ok((iter, converged))
}
805
806#[inline]
807#[allow(clippy::too_many_arguments)]
808#[allow(clippy::needless_range_loop)]
809fn update_feature(
810    j: usize,
811    x: &Matrix,
812    residuals: &mut [f64],
813    beta: &mut [f64],
814    lambda: f64,
815    alpha: f64,
816    penalty_factor: Option<&[f64]>,
817    bounds: Option<&[(f64, f64)]>,
818    column_squared_norms: &[f64],  // Column squared norms (for coordinate descent update)
819    maximum_coefficient_change: &mut f64
820) -> bool {
821    // Penalty factor
822    let penalty_factor_value = penalty_factor.and_then(|v| v.get(j)).copied().unwrap_or(1.0);
823    if penalty_factor_value == f64::INFINITY {
824        beta[j] = 0.0;
825        return false;
826    }
827
828    let n = x.rows;
829    let coefficient_previous = beta[j];
830
831    // Calculate partial residual correlation (rho)
832    // residuals currently = y - Sum(Xk * beta_k)
833    // We want r_partial = y - Sum_{k!=j}(Xk * beta_k) = residuals + Xj * beta_j
834    // rho = Xj^T * r_partial = Xj^T * residuals + (Xj^T * Xj) * beta_j
835    // where Xj^T * Xj = column_squared_norms[j] (the squared norm of column j after standardization)
836
837    let mut partial_correlation_unscaled = 0.0;
838    for i in 0..n {
839        partial_correlation_unscaled += x.get(i, j) * residuals[i];
840    }
841    // Use column_squared_norms[j] instead of assuming 1.0
842    let rho = partial_correlation_unscaled + column_squared_norms[j] * coefficient_previous;
843
844    // Soft thresholding
845    // Numerator: S(rho, lambda * alpha * penalty_factor_value)
846    let threshold = lambda * alpha * penalty_factor_value;
847    let soft_threshold_result = soft_threshold(rho, threshold);
848
849    // Denominator
850    // Elastic net denominator: column_squared_norms[j] + lambda * (1 - alpha) * penalty_factor_value
851    // This matches glmnet's formula
852    let denominator_with_ridge_penalty = column_squared_norms[j] + lambda * (1.0 - alpha) * penalty_factor_value;
853
854    let mut coefficient_updated = soft_threshold_result / denominator_with_ridge_penalty;
855
856    // Apply coefficient bounds (clamping) if provided
857    // Bounds clamp the calculated value to [lower, upper]
858    if let Some(bounds) = bounds {
859        // bounds[j-1] because bounds is indexed by predictor (excluding intercept)
860        // and j starts at first_penalized_column_index (usually 1 for intercept models)
861        let bounds_idx = j.saturating_sub(1);
862        if let Some((lower, upper)) = bounds.get(bounds_idx) {
863            coefficient_updated = coefficient_updated.max(*lower).min(*upper);
864        }
865    }
866
867    // Update residuals if beta changed
868    if coefficient_updated != coefficient_previous {
869        let coefficient_change = coefficient_updated - coefficient_previous;
870        for i in 0..n {
871            // residuals_new = residuals_old - x_j * coefficient_change
872            residuals[i] -= x.get(i, j) * coefficient_change;
873        }
874        beta[j] = coefficient_updated;
875        *maximum_coefficient_change = maximum_coefficient_change.max(coefficient_change.abs());
876        true // changed
877    } else {
878        false // no change
879    }
880}
881
// ============================================================================
// Model Serialization Traits
// ============================================================================

// Generate the `ModelSave`/`ModelLoad` implementations for `ElasticNetFit`
// via the shared serialization macro; the saved model is tagged with
// `ModelType::ElasticNet` and uses "ElasticNet" as its display name.
impl_serialization!(ElasticNetFit, ModelType::ElasticNet, "ElasticNet");
888
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n`-row design matrix: a leading all-ones intercept column
    /// followed by the supplied predictor columns (row-major storage).
    fn design_with_intercept(n: usize, predictors: &[Vec<f64>]) -> Matrix {
        let width = predictors.len() + 1;
        let mut data = vec![1.0; n * width];
        for (col, values) in predictors.iter().enumerate() {
            for row in 0..n {
                data[row * width + col + 1] = values[row];
            }
        }
        Matrix::new(n, width, data)
    }

    /// The predictor column 1.0, 2.0, ..., 5.0 shared by most fixtures.
    fn ramp_predictor() -> Vec<f64> {
        (1..=5).map(f64::from).collect()
    }

    #[test]
    fn test_soft_threshold_basic_cases() {
        // Magnitudes above gamma are shrunk toward zero by gamma.
        assert_eq!(soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
        // Anything inside [-gamma, gamma] snaps to zero (boundary included).
        assert_eq!(soft_threshold(1.0, 2.0), 0.0);
        assert_eq!(soft_threshold(2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-2.0, 2.0), 0.0);
    }

    #[test]
    fn test_soft_threshold_zero() {
        // A zero threshold makes soft-thresholding the identity map.
        assert_eq!(soft_threshold(0.0, 0.0), 0.0);
        assert_eq!(soft_threshold(5.0, 0.0), 5.0);
        assert_eq!(soft_threshold(-5.0, 0.0), -5.0);
    }

    #[test]
    #[should_panic(expected = "Soft threshold gamma must be non-negative")]
    fn test_soft_threshold_negative_gamma_panics() {
        soft_threshold(1.0, -1.0);
    }

    #[test]
    fn test_elastic_net_options_default() {
        let options = ElasticNetOptions::default();
        assert_eq!(options.lambda, 1.0);
        assert_eq!(options.alpha, 1.0); // default mixing is pure Lasso
        assert!(options.intercept);
        assert!(options.standardize);
        assert_eq!(options.max_iter, 100000);
        assert_eq!(options.tol, 1e-7);
        assert!(options.penalty_factor.is_none());
        assert!(options.warm_start.is_none());
        assert!(options.coefficient_bounds.is_none());
    }

    #[test]
    fn test_elastic_net_fit_simple() {
        // Exact linear relationship: y = 2*x + 1.
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x = design_with_intercept(5, &[ramp_predictor()]);

        let options = ElasticNetOptions {
            lambda: 0.01, // near-unregularized so the OLS solution is recoverable
            alpha: 0.5,
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        let fit = elastic_net_fit(&x, &y, &options).expect("fit should succeed");
        assert!(fit.converged);
        // Should land close to intercept 1 and slope 2.
        assert!((fit.intercept - 1.0).abs() < 0.5);
        assert!((fit.coefficients[0] - 2.0).abs() < 0.5);
    }

    #[test]
    fn test_elastic_net_fit_with_penalty_factor() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x = design_with_intercept(5, &[ramp_predictor()]);

        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            penalty_factor: Some(vec![1.0]),
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        assert!(elastic_net_fit(&x, &y, &options).is_ok());
    }

    #[test]
    fn test_elastic_net_fit_with_coefficient_bounds() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x = design_with_intercept(5, &[ramp_predictor()]);

        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            coefficient_bounds: Some(vec![(0.0, 3.0)]), // slope confined to [0, 3]
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        let fit = elastic_net_fit(&x, &y, &options).expect("fit should succeed");
        // The fitted slope must respect its declared bounds.
        assert!(fit.coefficients[0] >= 0.0);
        assert!(fit.coefficients[0] <= 3.0);
    }

    #[test]
    fn test_elastic_net_pure_lasso() {
        // alpha = 1.0 degenerates the elastic net to pure Lasso.
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x = design_with_intercept(5, &[ramp_predictor()]);

        let options = ElasticNetOptions {
            lambda: 1.0,
            alpha: 1.0,
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        assert!(elastic_net_fit(&x, &y, &options).is_ok());
    }

    #[test]
    fn test_elastic_net_pure_ridge() {
        // alpha = 0.0 degenerates the elastic net to pure Ridge.
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x = design_with_intercept(5, &[ramp_predictor()]);

        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.0,
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        let fit = elastic_net_fit(&x, &y, &options).expect("fit should succeed");
        // Ridge shrinks coefficients but never zeroes them out.
        assert!(fit.n_nonzero >= 1);
    }

    #[test]
    fn test_elastic_fit_no_intercept() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        // Single predictor column, no intercept column in the design.
        let x = Matrix::new(5, 1, ramp_predictor());

        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            intercept: false,
            standardize: true,
            ..Default::default()
        };

        assert!(elastic_net_fit(&x, &y, &options).is_ok());
    }

    #[test]
    fn test_elastic_net_with_warm_start() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x = design_with_intercept(5, &[ramp_predictor()]);

        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            warm_start: Some(vec![1.5]),
            ..Default::default()
        };

        assert!(elastic_net_fit(&x, &y, &options).is_ok());
    }

    #[test]
    fn test_elastic_net_multivariate() {
        // Two predictors plus an intercept column.
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let second = vec![2.0, 4.0, 5.0, 4.0, 3.0];
        let x = design_with_intercept(5, &[ramp_predictor(), second]);

        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        let fit = elastic_net_fit(&x, &y, &options).expect("fit should succeed");
        assert_eq!(fit.coefficients.len(), 2); // one coefficient per predictor
    }
}