Skip to main content

linreg_core/wasm/
regularized.rs

1//! Regularized regression methods for WASM
2//!
3//! Provides WASM bindings for Ridge, Lasso, and Elastic Net regression.
4
5#![cfg(feature = "wasm")]
6
7use wasm_bindgen::prelude::*;
8
9use super::domain::check_domain;
10use crate::error::{error_json, error_to_json};
11use crate::linalg;
12use crate::regularized;
13
14/// Performs Ridge regression via WASM.
15///
16/// Ridge regression adds an L2 penalty to the coefficients, which helps with
17/// multicollinearity and overfitting. The intercept is never penalized.
18///
19/// # Arguments
20///
21/// * `y_json` - JSON array of response variable values
22/// * `x_vars_json` - JSON array of predictor arrays
23/// * `variable_names` - JSON array of variable names
24/// * `lambda` - Regularization strength (>= 0, typical range 0.01 to 100)
25/// * `standardize` - Whether to standardize predictors (recommended: true)
26///
27/// # Returns
28///
29/// JSON string containing:
30/// - `lambda` - Lambda value used
31/// - `intercept` - Intercept coefficient
32/// - `coefficients` - Slope coefficients
33/// - `fitted_values` - Predictions on training data
34/// - `residuals` - Residuals (y - fitted_values)
35/// - `df` - Effective degrees of freedom
36///
37/// # Errors
38///
39/// Returns a JSON error object if parsing fails, lambda is negative,
40/// or domain check fails.
41#[wasm_bindgen]
42pub fn ridge_regression(
43    y_json: &str,
44    x_vars_json: &str,
45    _variable_names: &str,
46    lambda: f64,
47    standardize: bool,
48) -> String {
49    if let Err(e) = check_domain() {
50        return error_to_json(&e);
51    }
52
53    // Parse JSON input
54    let y: Vec<f64> = match serde_json::from_str(y_json) {
55        Ok(v) => v,
56        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
57    };
58
59    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
60        Ok(v) => v,
61        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
62    };
63
64    // Build design matrix with intercept column
65    let (x, n, p) = build_design_matrix(&y, &x_vars);
66
67    if n <= p + 1 {
68        return error_json(&format!(
69            "Insufficient data: need at least {} observations for {} predictors",
70            p + 2,
71            p
72        ));
73    }
74
75    // Configure ridge options
76    let options = regularized::ridge::RidgeFitOptions {
77        lambda,
78        intercept: true,
79        standardize,
80        max_iter: 100000,
81        tol: 1e-7,
82        warm_start: None,
83        weights: None,
84    };
85
86    match regularized::ridge::ridge_fit(&x, &y, &options) {
87        Ok(output) => serde_json::to_string(&output)
88            .unwrap_or_else(|_| error_json("Failed to serialize ridge regression result")),
89        Err(e) => error_json(&e.to_string()),
90    }
91}
92
93/// Performs Lasso regression via WASM.
94///
95/// Lasso regression adds an L1 penalty to the coefficients, which performs
96/// automatic variable selection by shrinking some coefficients to exactly zero.
97/// The intercept is never penalized.
98///
99/// # Arguments
100///
101/// * `y_json` - JSON array of response variable values
102/// * `x_vars_json` - JSON array of predictor arrays
103/// * `variable_names` - JSON array of variable names
104/// * `lambda` - Regularization strength (>= 0, typical range 0.01 to 10)
105/// * `standardize` - Whether to standardize predictors (recommended: true)
106/// * `max_iter` - Maximum coordinate descent iterations (default: 100000)
107/// * `tol` - Convergence tolerance (default: 1e-7)
108///
109/// # Returns
110///
111/// JSON string containing:
112/// - `lambda` - Lambda value used
113/// - `intercept` - Intercept coefficient
114/// - `coefficients` - Slope coefficients (some may be exactly zero)
115/// - `fitted_values` - Predictions on training data
116/// - `residuals` - Residuals (y - fitted_values)
117/// - `n_nonzero` - Number of non-zero coefficients (excluding intercept)
118/// - `iterations` - Number of coordinate descent iterations
119/// - `converged` - Whether the algorithm converged
120///
121/// # Errors
122///
123/// Returns a JSON error object if parsing fails, lambda is negative,
124/// or domain check fails.
125#[wasm_bindgen]
126pub fn lasso_regression(
127    y_json: &str,
128    x_vars_json: &str,
129    _variable_names: &str,
130    lambda: f64,
131    standardize: bool,
132    max_iter: usize,
133    tol: f64,
134) -> String {
135    if let Err(e) = check_domain() {
136        return error_to_json(&e);
137    }
138
139    // Parse JSON input
140    let y: Vec<f64> = match serde_json::from_str(y_json) {
141        Ok(v) => v,
142        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
143    };
144
145    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
146        Ok(v) => v,
147        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
148    };
149
150    // Build design matrix with intercept column
151    let (x, n, p) = build_design_matrix(&y, &x_vars);
152
153    if n <= p + 1 {
154        return error_json(&format!(
155            "Insufficient data: need at least {} observations for {} predictors",
156            p + 2,
157            p
158        ));
159    }
160
161    // Configure lasso options
162    let options = regularized::lasso::LassoFitOptions {
163        lambda,
164        intercept: true,
165        standardize,
166        max_iter,
167        tol,
168        ..Default::default()
169    };
170
171    match regularized::lasso::lasso_fit(&x, &y, &options) {
172        Ok(output) => serde_json::to_string(&output)
173            .unwrap_or_else(|_| error_json("Failed to serialize lasso regression result")),
174        Err(e) => error_json(&e.to_string()),
175    }
176}
177
178/// Performs Elastic Net regression via WASM.
179///
180/// Elastic Net combines L1 (Lasso) and L2 (Ridge) penalties.
181///
182/// # Arguments
183///
184/// * `y_json` - JSON array of response variable values
185/// * `x_vars_json` - JSON array of predictor arrays
186/// * `variable_names` - JSON array of variable names
187/// * `lambda` - Regularization strength (>= 0)
188/// * `alpha` - Elastic net mixing parameter (0 = Ridge, 1 = Lasso)
189/// * `standardize` - Whether to standardize predictors (recommended: true)
190/// * `max_iter` - Maximum coordinate descent iterations
191/// * `tol` - Convergence tolerance
192///
193/// # Returns
194///
195/// JSON string containing regression results (same fields as Lasso).
196///
197/// # Errors
198///
199/// Returns a JSON error object if parsing fails, parameters are invalid,
200/// or domain check fails.
201#[wasm_bindgen]
202#[allow(clippy::too_many_arguments)]
203pub fn elastic_net_regression(
204    y_json: &str,
205    x_vars_json: &str,
206    _variable_names: &str,
207    lambda: f64,
208    alpha: f64,
209    standardize: bool,
210    max_iter: usize,
211    tol: f64,
212) -> String {
213    if let Err(e) = check_domain() {
214        return error_to_json(&e);
215    }
216
217    // Parse JSON input
218    let y: Vec<f64> = match serde_json::from_str(y_json) {
219        Ok(v) => v,
220        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
221    };
222
223    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
224        Ok(v) => v,
225        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
226    };
227
228    // Build design matrix with intercept column
229    let (x, n, p) = build_design_matrix(&y, &x_vars);
230
231    if n <= p + 1 {
232        return error_json(&format!(
233            "Insufficient data: need at least {} observations for {} predictors",
234            p + 2,
235            p
236        ));
237    }
238
239    // Configure elastic net options
240    let options = regularized::elastic_net::ElasticNetOptions {
241        lambda,
242        alpha,
243        intercept: true,
244        standardize,
245        max_iter,
246        tol,
247        ..Default::default()
248    };
249
250    match regularized::elastic_net::elastic_net_fit(&x, &y, &options) {
251        Ok(output) => serde_json::to_string(&output)
252            .unwrap_or_else(|_| error_json("Failed to serialize elastic net regression result")),
253        Err(e) => error_json(&e.to_string()),
254    }
255}
256
257/// Generates a lambda path for regularized regression via WASM.
258///
259/// Creates a logarithmically-spaced sequence of lambda values from lambda_max
260/// (where all penalized coefficients are zero) down to lambda_min. This is
261/// useful for fitting regularization paths and selecting optimal lambda via
262/// cross-validation.
263///
264/// # Arguments
265///
266/// * `y_json` - JSON array of response variable values
267/// * `x_vars_json` - JSON array of predictor arrays
268/// * `n_lambda` - Number of lambda values to generate (default: 100)
269/// * `lambda_min_ratio` - Ratio for smallest lambda (default: 0.0001 if n >= p, else 0.01)
270///
271/// # Returns
272///
273/// JSON string containing:
274/// - `lambda_path` - Array of lambda values in decreasing order
275/// - `lambda_max` - Maximum lambda value
276/// - `lambda_min` - Minimum lambda value
277/// - `n_lambda` - Number of lambda values
278///
279/// # Errors
280///
281/// Returns a JSON error object if parsing fails or domain check fails.
282#[wasm_bindgen]
283pub fn make_lambda_path(
284    y_json: &str,
285    x_vars_json: &str,
286    n_lambda: usize,
287    lambda_min_ratio: f64,
288) -> String {
289    if let Err(e) = check_domain() {
290        return error_to_json(&e);
291    }
292
293    // Parse JSON input
294    let y: Vec<f64> = match serde_json::from_str(y_json) {
295        Ok(v) => v,
296        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
297    };
298
299    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
300        Ok(v) => v,
301        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
302    };
303
304    // Build design matrix with intercept column
305    let (x, n, p) = build_design_matrix(&y, &x_vars);
306
307    // Standardize X for lambda path computation
308    let x_mean: Vec<f64> = (0..x.cols)
309        .map(|j| {
310            if j == 0 {
311                1.0 // Intercept column
312            } else {
313                (0..n).map(|i| x.get(i, j)).sum::<f64>() / n as f64
314            }
315        })
316        .collect();
317
318    let x_standardized: Vec<f64> = (0..x.cols)
319        .map(|j| {
320            if j == 0 {
321                0.0 // Intercept column - no centering
322            } else {
323                let mean = x_mean[j];
324                let variance =
325                    (0..n).map(|i| (x.get(i, j) - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
326                variance.sqrt()
327            }
328        })
329        .collect();
330
331    // Build standardized X matrix
332    let mut x_standardized_data = vec![1.0; n * (p + 1)];
333    for j in 0..x.cols {
334        for i in 0..n {
335            if j == 0 {
336                x_standardized_data[i * (p + 1)] = 1.0; // Intercept
337            } else {
338                let std = x_standardized[j];
339                if std > 1e-10 {
340                    x_standardized_data[i * (p + 1) + j] = (x.get(i, j) - x_mean[j]) / std;
341                } else {
342                    x_standardized_data[i * (p + 1) + j] = 0.0;
343                }
344            }
345        }
346    }
347    let x_standardized = linalg::Matrix::new(n, p + 1, x_standardized_data);
348
349    // Center y
350    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
351    let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
352
353    // Configure lambda path options
354    let options = regularized::path::LambdaPathOptions {
355        nlambda: n_lambda.max(1),
356        lambda_min_ratio: if lambda_min_ratio > 0.0 {
357            Some(lambda_min_ratio)
358        } else {
359            None
360        },
361        alpha: 1.0, // Lasso
362        ..Default::default()
363    };
364
365    let lambda_path =
366        regularized::path::make_lambda_path(&x_standardized, &y_centered, &options, None, Some(0));
367
368    let lambda_max = lambda_path.first().copied().unwrap_or(0.0);
369    let lambda_min = lambda_path.last().copied().unwrap_or(0.0);
370
371    // Return as JSON (note: infinity serializes as null in JSON, handled in JS)
372    let result = serde_json::json!({
373        "lambda_path": lambda_path,
374        "lambda_max": lambda_max,
375        "lambda_min": lambda_min,
376        "n_lambda": lambda_path.len()
377    });
378
379    result.to_string()
380}
381
382/// Helper function to build a design matrix from column vectors.
383///
384/// # Arguments
385///
386/// * `y` - Response variable (used to determine n)
387/// * `x_vars` - Predictor column vectors
388///
389/// # Returns
390///
391/// A tuple of (Matrix, n, p) where p is the number of predictors (excluding intercept)
392fn build_design_matrix(y: &[f64], x_vars: &[Vec<f64>]) -> (linalg::Matrix, usize, usize) {
393    let n = y.len();
394    let p = x_vars.len();
395
396    let mut x_data = vec![1.0; n * (p + 1)]; // Intercept column
397    for (j, x_var) in x_vars.iter().enumerate() {
398        for (i, &val) in x_var.iter().enumerate() {
399            x_data[i * (p + 1) + j + 1] = val;
400        }
401    }
402
403    (linalg::Matrix::new(n, p + 1, x_data), n, p)
404}