use crate::core::{aic, bic, log_likelihood};
use crate::error::{Error, Result};
use crate::linalg::Matrix;
use crate::regularized::preprocess::{
predict, standardize_xy, unstandardize_coefficients, StandardizeOptions,
};
use crate::serialization::types::ModelType;
use crate::impl_serialization;
use serde::{Deserialize, Serialize};
/// Soft-thresholding operator used by coordinate descent: shrinks `z`
/// toward zero by `gamma`, returning 0.0 whenever `z` lies inside
/// `[-gamma, gamma]`.
///
/// # Panics
/// Panics if `gamma` is negative.
#[inline]
pub fn soft_threshold(z: f64, gamma: f64) -> f64 {
    if gamma < 0.0 {
        panic!("Soft threshold gamma must be non-negative");
    }
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        // Covers |z| <= gamma and non-finite comparisons that fail both guards.
        _ => 0.0,
    }
}
/// Configuration for a single elastic net fit.
#[derive(Clone, Debug)]
pub struct ElasticNetOptions {
    /// Overall regularization strength; must be non-negative.
    pub lambda: f64,
    /// Mix between penalties: 1.0 = pure lasso (L1), 0.0 = pure ridge (L2).
    pub alpha: f64,
    /// Whether the first column of X is an intercept column of ones.
    pub intercept: bool,
    /// Whether to standardize predictors before fitting.
    pub standardize: bool,
    /// Maximum number of coordinate descent sweeps.
    pub max_iter: usize,
    /// Convergence tolerance on the largest coefficient change per sweep.
    pub tol: f64,
    /// Optional per-feature penalty multipliers; `f64::INFINITY` pins a
    /// coefficient to zero (feature excluded).
    pub penalty_factor: Option<Vec<f64>>,
    /// Optional initial coefficients on the original scale used to warm-start
    /// the solver; ignored if the length does not match the predictor count.
    pub warm_start: Option<Vec<f64>>,
    /// Optional observation weights forwarded to standardization.
    pub weights: Option<Vec<f64>>,
    /// Optional per-predictor `(lower, upper)` box constraints on the
    /// original scale.
    pub coefficient_bounds: Option<Vec<(f64, f64)>>,
}
impl Default for ElasticNetOptions {
fn default() -> Self {
ElasticNetOptions {
lambda: 1.0,
alpha: 1.0, intercept: true,
standardize: true,
max_iter: 100000,
tol: 1e-7,
penalty_factor: None,
warm_start: None,
weights: None,
coefficient_bounds: None,
}
}
}
/// Result of an elastic net fit, with coefficients reported on the
/// original (unstandardized) scale alongside goodness-of-fit statistics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ElasticNetFit {
    /// Regularization strength used, on the original y scale.
    pub lambda: f64,
    /// L1/L2 mixing parameter used for the fit.
    pub alpha: f64,
    /// Fitted intercept on the original scale.
    pub intercept: f64,
    /// Fitted predictor coefficients on the original scale (intercept excluded).
    pub coefficients: Vec<f64>,
    /// In-sample predictions `intercept + X * coefficients`.
    pub fitted_values: Vec<f64>,
    /// Residuals `y - fitted_values`.
    pub residuals: Vec<f64>,
    /// Number of coefficients with nonzero magnitude.
    pub n_nonzero: usize,
    /// Coordinate descent sweeps performed.
    pub iterations: usize,
    /// Whether the tolerance criterion was met within `max_iter` sweeps.
    pub converged: bool,
    /// Coefficient of determination (1 when total variance is ~0).
    pub r_squared: f64,
    /// R² adjusted for the effective degrees of freedom (intercept + nonzeros).
    pub adj_r_squared: f64,
    /// Residual sum of squares divided by residual degrees of freedom.
    pub mse: f64,
    /// Square root of `mse`.
    pub rmse: f64,
    /// Mean absolute residual.
    pub mae: f64,
    /// Gaussian log-likelihood of the fit.
    pub log_likelihood: f64,
    /// Akaike information criterion.
    pub aic: f64,
    /// Bayesian information criterion.
    pub bic: f64,
}
use crate::regularized::path::{make_lambda_path, LambdaPathOptions};
pub fn elastic_net_path(
x: &Matrix,
y: &[f64],
path_options: &LambdaPathOptions,
fit_options: &ElasticNetOptions,
) -> Result<Vec<ElasticNetFit>> {
let n = x.rows;
let p = x.cols;
if y.len() != n {
return Err(Error::DimensionMismatch(format!(
"Length of y ({}) must match number of rows in X ({})",
y.len(), n
)));
}
let standardization_options = StandardizeOptions {
intercept: fit_options.intercept,
standardize_x: fit_options.standardize,
standardize_y: fit_options.intercept,
weights: fit_options.weights.clone(),
};
let (x_standardized, y_standardized, standardization_info) = standardize_xy(x, y, &standardization_options);
let intercept_col = if fit_options.intercept { Some(0) } else { None };
let lambdas = make_lambda_path(
&x_standardized,
&y_standardized, path_options,
fit_options.penalty_factor.as_deref(),
intercept_col
);
let mut fits = Vec::with_capacity(lambdas.len());
let mut coefficients_standardized = vec![0.0; p];
let first_penalized_column_index = if fit_options.intercept { 1 } else { 0 };
let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
let lambda_conversion_factor = if y_scale_factor > 1e-12 {
y_scale_factor
} else {
1.0
};
for &lambda_standardized_value in &lambdas {
let lambda_standardized = lambda_standardized_value;
let bounds_standardized: Option<Vec<(f64, f64)>> = fit_options.coefficient_bounds.as_ref().map(|bounds| {
let y_scale = standardization_info.y_scale.unwrap_or(1.0);
bounds.iter().enumerate().map(|(j, &(lower, upper))| {
let std_idx = j + 1;
let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
standardization_info.x_scale[std_idx]
} else {
1.0
};
let scale_factor = x_scale_predictor_j / y_scale;
(lower * scale_factor, upper * scale_factor)
}).collect()
});
let (iterations, converged) = coordinate_descent(
&x_standardized,
&y_standardized,
&mut coefficients_standardized,
lambda_standardized,
fit_options.alpha,
first_penalized_column_index,
fit_options.max_iter,
fit_options.tol,
fit_options.penalty_factor.as_deref(),
bounds_standardized.as_deref(),
&standardization_info.column_squared_norms,
)?;
let (intercept, beta_orig) = unstandardize_coefficients(&coefficients_standardized, &standardization_info);
let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();
let fitted = predict(x, intercept, &beta_orig);
let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();
let y_mean = y.iter().sum::<f64>() / n as f64;
let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
let eff_df = 1.0 + n_nonzero as f64;
let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
} else {
r_squared
};
let mse = ss_res / (n as f64 - eff_df).max(1.0);
let ll = log_likelihood(n, mse, ss_res);
let n_coef = beta_orig.len() + 1; let aic_val = aic(ll, n_coef);
let bic_val = bic(ll, n_coef, n);
let lambda_original_scale = lambda_standardized_value * lambda_conversion_factor;
fits.push(ElasticNetFit {
lambda: lambda_original_scale,
alpha: fit_options.alpha,
intercept,
coefficients: beta_orig,
fitted_values: fitted,
residuals,
n_nonzero,
iterations,
converged,
r_squared,
adj_r_squared,
mse,
rmse: mse.sqrt(),
mae,
log_likelihood: ll,
aic: aic_val,
bic: bic_val,
});
}
Ok(fits)
}
/// Fits a single elastic net model by cyclic coordinate descent.
///
/// `x` is the design matrix (including a leading intercept column when
/// `options.intercept` is true) and `y` the response of length `x.rows`.
/// Coefficients and fit statistics are reported on the original scale.
///
/// # Errors
/// Returns `Error::InvalidInput` when `lambda < 0`, `alpha` is outside
/// `[0, 1]`, or `coefficient_bounds` have the wrong length or inverted
/// limits; `Error::DimensionMismatch` when `y.len() != x.rows`.
pub fn elastic_net_fit(x: &Matrix, y: &[f64], options: &ElasticNetOptions) -> Result<ElasticNetFit> {
    if options.lambda < 0.0 {
        return Err(Error::InvalidInput("Lambda must be non-negative".into()));
    }
    if options.alpha < 0.0 || options.alpha > 1.0 {
        return Err(Error::InvalidInput("Alpha must be between 0 and 1".into()));
    }
    let n = x.rows;
    let p = x.cols;
    if y.len() != n {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match number of rows in X ({})",
            y.len(),
            n
        )));
    }
    // Number of penalized predictors (the intercept column is not a predictor).
    let n_predictors = if options.intercept { p - 1 } else { p };
    if let Some(ref bounds) = options.coefficient_bounds {
        if bounds.len() != n_predictors {
            return Err(Error::InvalidInput(format!(
                "Coefficient bounds length ({}) must match number of predictors ({})",
                bounds.len(),
                n_predictors
            )));
        }
        for (i, &(lower, upper)) in bounds.iter().enumerate() {
            if lower > upper {
                return Err(Error::InvalidInput(format!(
                    "Coefficient bounds for predictor {}: lower ({}) must be <= upper ({})",
                    i, lower, upper
                )));
            }
        }
    }
    let standardization_options = StandardizeOptions {
        intercept: options.intercept,
        standardize_x: options.standardize,
        standardize_y: options.intercept,
        weights: options.weights.clone(),
    };
    let (x_standardized, y_standardized, standardization_info) =
        standardize_xy(x, y, &standardization_options);
    // Lambda is given on the original y scale but the solver works on
    // standardized y; rescale it, guarding against a degenerate y scale.
    let y_scale_factor = standardization_info.y_scale.unwrap_or(1.0);
    let lambda_standardized = if y_scale_factor > 1e-12 {
        options.lambda / y_scale_factor
    } else {
        options.lambda
    };
    let mut coefficients_standardized = vec![0.0; p];
    let first_penalized_column_index = if options.intercept { 1 } else { 0 };
    // Warm start: convert user-supplied coefficients (original scale) to the
    // standardized scale. A length mismatch is silently ignored.
    if let Some(warm) = &options.warm_start {
        let y_scale = standardization_info.y_scale.unwrap_or(1.0);
        if first_penalized_column_index == 1 {
            if warm.len() == p - 1 {
                for j in 1..p {
                    coefficients_standardized[j] =
                        warm[j - 1] * standardization_info.x_scale[j] / y_scale;
                }
            }
        } else if warm.len() == p {
            for j in 0..p {
                coefficients_standardized[j] =
                    warm[j] * standardization_info.x_scale[j] / y_scale;
            }
        }
    }
    // Convert box constraints from the original to the standardized scale.
    let bounds_standardized: Option<Vec<(f64, f64)>> =
        options.coefficient_bounds.as_ref().map(|bounds| {
            let y_scale = standardization_info.y_scale.unwrap_or(1.0);
            bounds
                .iter()
                .enumerate()
                .map(|(j, &(lower, upper))| {
                    // NOTE(review): assumes predictor j occupies column j + 1
                    // (intercept column first); confirm for no-intercept fits.
                    let std_idx = j + 1;
                    let x_scale_predictor_j = if std_idx < standardization_info.x_scale.len() {
                        standardization_info.x_scale[std_idx]
                    } else {
                        1.0
                    };
                    let scale_factor = x_scale_predictor_j / y_scale;
                    (lower * scale_factor, upper * scale_factor)
                })
                .collect()
        });
    let (iterations, converged) = coordinate_descent(
        &x_standardized,
        &y_standardized,
        &mut coefficients_standardized,
        lambda_standardized,
        options.alpha,
        first_penalized_column_index,
        options.max_iter,
        options.tol,
        options.penalty_factor.as_deref(),
        bounds_standardized.as_deref(),
        &standardization_info.column_squared_norms,
    )?;
    let (intercept, beta_orig) =
        unstandardize_coefficients(&coefficients_standardized, &standardization_info);
    let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();
    let fitted = predict(x, intercept, &beta_orig);
    let residuals: Vec<f64> = y.iter().zip(&fitted).map(|(yi, yh)| yi - yh).collect();
    let y_mean = y.iter().sum::<f64>() / n as f64;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    // ss_res is computed exactly once; a previous revision redundantly
    // recomputed it with an identical formula before log_likelihood.
    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;
    let r_squared = if ss_tot > 1e-10 { 1.0 - ss_res / ss_tot } else { 1.0 };
    // Effective degrees of freedom: intercept plus selected predictors.
    let eff_df = 1.0 + n_nonzero as f64;
    let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
    } else {
        r_squared
    };
    let mse = ss_res / (n as f64 - eff_df).max(1.0);
    let ll = log_likelihood(n, mse, ss_res);
    let n_coef = beta_orig.len() + 1;
    let aic_val = aic(ll, n_coef);
    let bic_val = bic(ll, n_coef, n);
    Ok(ElasticNetFit {
        lambda: options.lambda,
        alpha: options.alpha,
        intercept,
        coefficients: beta_orig,
        fitted_values: fitted,
        residuals,
        n_nonzero,
        iterations,
        converged,
        r_squared,
        adj_r_squared,
        mse,
        rmse: mse.sqrt(),
        mae,
        log_likelihood: ll,
        aic: aic_val,
        bic: bic_val,
    })
}
/// Cyclic coordinate descent on standardized data, run until the largest
/// coefficient change in a sweep drops below `tol` or `max_iter` sweeps
/// elapse.
///
/// `beta` is updated in place and may carry a warm start. Returns
/// `(sweeps_performed, converged)`.
///
/// Strategy: after each full sweep over the penalized columns
/// (`first_penalized_column_index..p`), iterate only over the columns whose
/// coefficients recently changed ("active set") until those stabilize, then
/// do another full sweep.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn coordinate_descent(
    x: &Matrix,
    y: &[f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    first_penalized_column_index: usize,
    max_iter: usize,
    tol: f64,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64],
) -> Result<(usize, bool)> {
    let n = x.rows;
    let p = x.cols;
    // Residuals r = y - X*beta, maintained incrementally by update_feature.
    let mut residuals = y.to_vec();
    // Refuse to iterate on non-finite input; report no iterations, not converged.
    if residuals.iter().any(|r| !r.is_finite()) {
        return Ok((0, false));
    }
    // Fold any warm-start coefficients into the initial residuals.
    for j in 0..p {
        if beta[j] != 0.0 {
            for i in 0..n {
                residuals[i] -= x.get(i, j) * beta[j];
            }
        }
    }
    let mut active_set = vec![false; p];
    let mut converged = false;
    let mut iter = 0;
    while iter < max_iter {
        // Full sweep: update every penalized column, marking columns whose
        // coefficient changed as active.
        let mut maximum_coefficient_change = 0.0;
        for j in first_penalized_column_index..p {
            if update_feature(j, x, &mut residuals, beta, lambda, alpha, penalty_factor, bounds, column_squared_norms, &mut maximum_coefficient_change) {
                active_set[j] = true;
            }
        }
        iter += 1;
        if maximum_coefficient_change < tol {
            converged = true;
            break;
        }
        // Inner loop: sweep only the active set until it stabilizes (or the
        // iteration budget runs out), then return to a full sweep.
        loop {
            if iter >= max_iter { break; }
            let mut active_set_coefficient_change = 0.0;
            let mut active_count = 0;
            for j in first_penalized_column_index..p {
                if active_set[j] {
                    update_feature(j, x, &mut residuals, beta, lambda, alpha, penalty_factor, bounds, column_squared_norms, &mut active_set_coefficient_change);
                    active_count += 1;
                    // Coefficients shrunk to exactly zero leave the active set.
                    if beta[j] == 0.0 {
                        active_set[j] = false;
                    }
                }
            }
            iter += 1;
            if active_set_coefficient_change < tol {
                break;
            }
            if active_count == 0 {
                break;
            }
        }
    }
    Ok((iter, converged))
}
/// Performs one coordinate-descent update on column `j`, adjusting `beta[j]`
/// and the shared residual vector in place.
///
/// Returns `true` when the coefficient actually changed; in that case
/// `maximum_coefficient_change` is raised to at least the absolute change.
#[inline]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn update_feature(
    j: usize,
    x: &Matrix,
    residuals: &mut [f64],
    beta: &mut [f64],
    lambda: f64,
    alpha: f64,
    penalty_factor: Option<&[f64]>,
    bounds: Option<&[(f64, f64)]>,
    column_squared_norms: &[f64],
    maximum_coefficient_change: &mut f64
) -> bool {
    // An infinite penalty factor forces the coefficient to zero (feature excluded).
    let penalty_factor_value = penalty_factor.and_then(|v| v.get(j)).copied().unwrap_or(1.0);
    if penalty_factor_value == f64::INFINITY {
        beta[j] = 0.0;
        return false;
    }
    let n = x.rows;
    let coefficient_previous = beta[j];
    // rho = x_j . (r + x_j * beta_j): the column's own contribution is added
    // back so the update is computed against the partial residual.
    let mut partial_correlation_unscaled = 0.0;
    for i in 0..n {
        partial_correlation_unscaled += x.get(i, j) * residuals[i];
    }
    let rho = partial_correlation_unscaled + column_squared_norms[j] * coefficient_previous;
    // L1 part of the penalty enters via soft-thresholding...
    let threshold = lambda * alpha * penalty_factor_value;
    let soft_threshold_result = soft_threshold(rho, threshold);
    // ...and the L2 (ridge) part inflates the denominator.
    let denominator_with_ridge_penalty = column_squared_norms[j] + lambda * (1.0 - alpha) * penalty_factor_value;
    let mut coefficient_updated = soft_threshold_result / denominator_with_ridge_penalty;
    // Clamp to the standardized-scale box constraints when provided.
    if let Some(bounds) = bounds {
        // NOTE(review): mapping column j to bounds[j - 1] assumes an intercept
        // column at index 0; for a no-intercept fit with bounds this looks off
        // by one — confirm against the callers before relying on it.
        let bounds_idx = j.saturating_sub(1);
        if let Some((lower, upper)) = bounds.get(bounds_idx) {
            coefficient_updated = coefficient_updated.max(*lower).min(*upper);
        }
    }
    if coefficient_updated != coefficient_previous {
        let coefficient_change = coefficient_updated - coefficient_previous;
        // Keep residuals consistent with the new coefficient: r -= x_j * delta.
        for i in 0..n {
            residuals[i] -= x.get(i, j) * coefficient_change;
        }
        beta[j] = coefficient_updated;
        *maximum_coefficient_change = maximum_coefficient_change.max(coefficient_change.abs());
        true
    } else {
        false
    }
}
// Wire ElasticNetFit into the crate's model serialization machinery under
// ModelType::ElasticNet (macro defined in crate::serialization).
impl_serialization!(ElasticNetFit, ModelType::ElasticNet, "ElasticNet");
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row-major design matrix: a leading intercept column of 1.0
    /// followed by the given predictor columns (each of length `n`).
    fn design_with_intercept(n: usize, predictors: &[&[f64]]) -> Matrix {
        let p = predictors.len();
        let mut x_data = vec![1.0; n * (p + 1)];
        for i in 0..n {
            for (k, col) in predictors.iter().enumerate() {
                x_data[i * (p + 1) + k + 1] = col[i];
            }
        }
        Matrix::new(n, p + 1, x_data)
    }

    /// The increasing predictor 1.0..=5.0 shared by most fixtures.
    fn x1_column() -> Vec<f64> {
        (1..=5).map(|i| i as f64).collect()
    }

    #[test]
    fn test_soft_threshold_basic_cases() {
        assert_eq!(soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(soft_threshold(1.0, 2.0), 0.0);
        assert_eq!(soft_threshold(2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-2.0, 2.0), 0.0);
    }

    #[test]
    fn test_soft_threshold_zero() {
        assert_eq!(soft_threshold(0.0, 0.0), 0.0);
        assert_eq!(soft_threshold(5.0, 0.0), 5.0);
        assert_eq!(soft_threshold(-5.0, 0.0), -5.0);
    }

    #[test]
    #[should_panic(expected = "Soft threshold gamma must be non-negative")]
    fn test_soft_threshold_negative_gamma_panics() {
        soft_threshold(1.0, -1.0);
    }

    #[test]
    fn test_elastic_net_options_default() {
        let options = ElasticNetOptions::default();
        assert_eq!(options.lambda, 1.0);
        assert_eq!(options.alpha, 1.0);
        assert!(options.intercept);
        assert!(options.standardize);
        assert_eq!(options.max_iter, 100000);
        assert_eq!(options.tol, 1e-7);
        assert!(options.penalty_factor.is_none());
        assert!(options.warm_start.is_none());
        assert!(options.coefficient_bounds.is_none());
    }

    #[test]
    fn test_elastic_net_fit_simple() {
        // y = 1 + 2x, so a lightly penalized fit should land near those values.
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x = design_with_intercept(5, &[&x1]);
        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
        let fit = result.unwrap();
        assert!(fit.converged);
        assert!((fit.intercept - 1.0).abs() < 0.5);
        assert!((fit.coefficients[0] - 2.0).abs() < 0.5);
    }

    #[test]
    fn test_elastic_net_fit_with_penalty_factor() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x = design_with_intercept(5, &[&x1]);
        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            penalty_factor: Some(vec![1.0]),
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
    }

    #[test]
    fn test_elastic_net_fit_with_coefficient_bounds() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x = design_with_intercept(5, &[&x1]);
        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            coefficient_bounds: Some(vec![(0.0, 3.0)]),
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
        let fit = result.unwrap();
        assert!(fit.coefficients[0] >= 0.0);
        assert!(fit.coefficients[0] <= 3.0);
    }

    #[test]
    fn test_elastic_net_pure_lasso() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x = design_with_intercept(5, &[&x1]);
        let options = ElasticNetOptions {
            lambda: 1.0,
            alpha: 1.0,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
    }

    #[test]
    fn test_elastic_net_pure_ridge() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x = design_with_intercept(5, &[&x1]);
        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.0,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
        let fit = result.unwrap();
        // Ridge does not zero coefficients, so the predictor should survive.
        assert!(fit.n_nonzero >= 1);
    }

    #[test]
    fn test_elastic_fit_no_intercept() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x1 = x1_column();
        // No intercept column: the design matrix is the single predictor.
        let x = Matrix::new(5, 1, x1);
        let options = ElasticNetOptions {
            lambda: 0.01,
            alpha: 0.5,
            intercept: false,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
    }

    #[test]
    fn test_elastic_net_with_warm_start() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x = design_with_intercept(5, &[&x1]);
        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            warm_start: Some(vec![1.5]),
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
    }

    #[test]
    fn test_elastic_net_multivariate() {
        let y = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let x1 = x1_column();
        let x2 = vec![2.0, 4.0, 5.0, 4.0, 3.0];
        let x = design_with_intercept(5, &[&x1, &x2]);
        let options = ElasticNetOptions {
            lambda: 0.1,
            alpha: 0.5,
            intercept: true,
            standardize: true,
            ..Default::default()
        };
        let result = elastic_net_fit(&x, &y, &options);
        assert!(result.is_ok());
        let fit = result.unwrap();
        assert_eq!(fit.coefficients.len(), 2);
    }
}