linreg_core/wasm/regularized.rs
1//! Regularized regression methods for WASM
2//!
3//! Provides WASM bindings for Ridge, Lasso, and Elastic Net regression.
4
5#![cfg(feature = "wasm")]
6
7use wasm_bindgen::prelude::*;
8
9use super::domain::check_domain;
10use crate::error::{error_json, error_to_json};
11use crate::linalg;
12use crate::regularized;
13
14/// Performs Ridge regression via WASM.
15///
16/// Ridge regression adds an L2 penalty to the coefficients, which helps with
17/// multicollinearity and overfitting. The intercept is never penalized.
18///
19/// # Arguments
20///
21/// * `y_json` - JSON array of response variable values
22/// * `x_vars_json` - JSON array of predictor arrays
23/// * `variable_names` - JSON array of variable names
24/// * `lambda` - Regularization strength (>= 0, typical range 0.01 to 100)
25/// * `standardize` - Whether to standardize predictors (recommended: true)
26///
27/// # Returns
28///
29/// JSON string containing:
30/// - `lambda` - Lambda value used
31/// - `intercept` - Intercept coefficient
32/// - `coefficients` - Slope coefficients
33/// - `fitted_values` - Predictions on training data
34/// - `residuals` - Residuals (y - fitted_values)
35/// - `df` - Effective degrees of freedom
36///
37/// # Errors
38///
39/// Returns a JSON error object if parsing fails, lambda is negative,
40/// or domain check fails.
41#[wasm_bindgen]
42pub fn ridge_regression(
43 y_json: &str,
44 x_vars_json: &str,
45 _variable_names: &str,
46 lambda: f64,
47 standardize: bool,
48) -> String {
49 if let Err(e) = check_domain() {
50 return error_to_json(&e);
51 }
52
53 // Parse JSON input
54 let y: Vec<f64> = match serde_json::from_str(y_json) {
55 Ok(v) => v,
56 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
57 };
58
59 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
60 Ok(v) => v,
61 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
62 };
63
64 // Build design matrix with intercept column
65 let (x, n, p) = build_design_matrix(&y, &x_vars);
66
67 if n <= p + 1 {
68 return error_json(&format!(
69 "Insufficient data: need at least {} observations for {} predictors",
70 p + 2,
71 p
72 ));
73 }
74
75 // Configure ridge options
76 let options = regularized::ridge::RidgeFitOptions {
77 lambda,
78 intercept: true,
79 standardize,
80 max_iter: 100000,
81 tol: 1e-7,
82 warm_start: None,
83 weights: None,
84 };
85
86 match regularized::ridge::ridge_fit(&x, &y, &options) {
87 Ok(output) => serde_json::to_string(&output)
88 .unwrap_or_else(|_| error_json("Failed to serialize ridge regression result")),
89 Err(e) => error_json(&e.to_string()),
90 }
91}
92
93/// Performs Lasso regression via WASM.
94///
95/// Lasso regression adds an L1 penalty to the coefficients, which performs
96/// automatic variable selection by shrinking some coefficients to exactly zero.
97/// The intercept is never penalized.
98///
99/// # Arguments
100///
101/// * `y_json` - JSON array of response variable values
102/// * `x_vars_json` - JSON array of predictor arrays
103/// * `variable_names` - JSON array of variable names
104/// * `lambda` - Regularization strength (>= 0, typical range 0.01 to 10)
105/// * `standardize` - Whether to standardize predictors (recommended: true)
106/// * `max_iter` - Maximum coordinate descent iterations (default: 100000)
107/// * `tol` - Convergence tolerance (default: 1e-7)
108///
109/// # Returns
110///
111/// JSON string containing:
112/// - `lambda` - Lambda value used
113/// - `intercept` - Intercept coefficient
114/// - `coefficients` - Slope coefficients (some may be exactly zero)
115/// - `fitted_values` - Predictions on training data
116/// - `residuals` - Residuals (y - fitted_values)
117/// - `n_nonzero` - Number of non-zero coefficients (excluding intercept)
118/// - `iterations` - Number of coordinate descent iterations
119/// - `converged` - Whether the algorithm converged
120///
121/// # Errors
122///
123/// Returns a JSON error object if parsing fails, lambda is negative,
124/// or domain check fails.
125#[wasm_bindgen]
126pub fn lasso_regression(
127 y_json: &str,
128 x_vars_json: &str,
129 _variable_names: &str,
130 lambda: f64,
131 standardize: bool,
132 max_iter: usize,
133 tol: f64,
134) -> String {
135 if let Err(e) = check_domain() {
136 return error_to_json(&e);
137 }
138
139 // Parse JSON input
140 let y: Vec<f64> = match serde_json::from_str(y_json) {
141 Ok(v) => v,
142 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
143 };
144
145 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
146 Ok(v) => v,
147 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
148 };
149
150 // Build design matrix with intercept column
151 let (x, n, p) = build_design_matrix(&y, &x_vars);
152
153 if n <= p + 1 {
154 return error_json(&format!(
155 "Insufficient data: need at least {} observations for {} predictors",
156 p + 2,
157 p
158 ));
159 }
160
161 // Configure lasso options
162 let options = regularized::lasso::LassoFitOptions {
163 lambda,
164 intercept: true,
165 standardize,
166 max_iter,
167 tol,
168 ..Default::default()
169 };
170
171 match regularized::lasso::lasso_fit(&x, &y, &options) {
172 Ok(output) => serde_json::to_string(&output)
173 .unwrap_or_else(|_| error_json("Failed to serialize lasso regression result")),
174 Err(e) => error_json(&e.to_string()),
175 }
176}
177
178/// Performs Elastic Net regression via WASM.
179///
180/// Elastic Net combines L1 (Lasso) and L2 (Ridge) penalties.
181///
182/// # Arguments
183///
184/// * `y_json` - JSON array of response variable values
185/// * `x_vars_json` - JSON array of predictor arrays
186/// * `variable_names` - JSON array of variable names
187/// * `lambda` - Regularization strength (>= 0)
188/// * `alpha` - Elastic net mixing parameter (0 = Ridge, 1 = Lasso)
189/// * `standardize` - Whether to standardize predictors (recommended: true)
190/// * `max_iter` - Maximum coordinate descent iterations
191/// * `tol` - Convergence tolerance
192///
193/// # Returns
194///
195/// JSON string containing regression results (same fields as Lasso).
196///
197/// # Errors
198///
199/// Returns a JSON error object if parsing fails, parameters are invalid,
200/// or domain check fails.
201#[wasm_bindgen]
202#[allow(clippy::too_many_arguments)]
203pub fn elastic_net_regression(
204 y_json: &str,
205 x_vars_json: &str,
206 _variable_names: &str,
207 lambda: f64,
208 alpha: f64,
209 standardize: bool,
210 max_iter: usize,
211 tol: f64,
212) -> String {
213 if let Err(e) = check_domain() {
214 return error_to_json(&e);
215 }
216
217 // Parse JSON input
218 let y: Vec<f64> = match serde_json::from_str(y_json) {
219 Ok(v) => v,
220 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
221 };
222
223 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
224 Ok(v) => v,
225 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
226 };
227
228 // Build design matrix with intercept column
229 let (x, n, p) = build_design_matrix(&y, &x_vars);
230
231 if n <= p + 1 {
232 return error_json(&format!(
233 "Insufficient data: need at least {} observations for {} predictors",
234 p + 2,
235 p
236 ));
237 }
238
239 // Configure elastic net options
240 let options = regularized::elastic_net::ElasticNetOptions {
241 lambda,
242 alpha,
243 intercept: true,
244 standardize,
245 max_iter,
246 tol,
247 ..Default::default()
248 };
249
250 match regularized::elastic_net::elastic_net_fit(&x, &y, &options) {
251 Ok(output) => serde_json::to_string(&output)
252 .unwrap_or_else(|_| error_json("Failed to serialize elastic net regression result")),
253 Err(e) => error_json(&e.to_string()),
254 }
255}
256
257/// Generates a lambda path for regularized regression via WASM.
258///
259/// Creates a logarithmically-spaced sequence of lambda values from lambda_max
260/// (where all penalized coefficients are zero) down to lambda_min. This is
261/// useful for fitting regularization paths and selecting optimal lambda via
262/// cross-validation.
263///
264/// # Arguments
265///
266/// * `y_json` - JSON array of response variable values
267/// * `x_vars_json` - JSON array of predictor arrays
268/// * `n_lambda` - Number of lambda values to generate (default: 100)
269/// * `lambda_min_ratio` - Ratio for smallest lambda (default: 0.0001 if n >= p, else 0.01)
270///
271/// # Returns
272///
273/// JSON string containing:
274/// - `lambda_path` - Array of lambda values in decreasing order
275/// - `lambda_max` - Maximum lambda value
276/// - `lambda_min` - Minimum lambda value
277/// - `n_lambda` - Number of lambda values
278///
279/// # Errors
280///
281/// Returns a JSON error object if parsing fails or domain check fails.
282#[wasm_bindgen]
283pub fn make_lambda_path(
284 y_json: &str,
285 x_vars_json: &str,
286 n_lambda: usize,
287 lambda_min_ratio: f64,
288) -> String {
289 if let Err(e) = check_domain() {
290 return error_to_json(&e);
291 }
292
293 // Parse JSON input
294 let y: Vec<f64> = match serde_json::from_str(y_json) {
295 Ok(v) => v,
296 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
297 };
298
299 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
300 Ok(v) => v,
301 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
302 };
303
304 // Build design matrix with intercept column
305 let (x, n, p) = build_design_matrix(&y, &x_vars);
306
307 // Standardize X for lambda path computation
308 let x_mean: Vec<f64> = (0..x.cols)
309 .map(|j| {
310 if j == 0 {
311 1.0 // Intercept column
312 } else {
313 (0..n).map(|i| x.get(i, j)).sum::<f64>() / n as f64
314 }
315 })
316 .collect();
317
318 let x_standardized: Vec<f64> = (0..x.cols)
319 .map(|j| {
320 if j == 0 {
321 0.0 // Intercept column - no centering
322 } else {
323 let mean = x_mean[j];
324 let variance =
325 (0..n).map(|i| (x.get(i, j) - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
326 variance.sqrt()
327 }
328 })
329 .collect();
330
331 // Build standardized X matrix
332 let mut x_standardized_data = vec![1.0; n * (p + 1)];
333 for j in 0..x.cols {
334 for i in 0..n {
335 if j == 0 {
336 x_standardized_data[i * (p + 1)] = 1.0; // Intercept
337 } else {
338 let std = x_standardized[j];
339 if std > 1e-10 {
340 x_standardized_data[i * (p + 1) + j] = (x.get(i, j) - x_mean[j]) / std;
341 } else {
342 x_standardized_data[i * (p + 1) + j] = 0.0;
343 }
344 }
345 }
346 }
347 let x_standardized = linalg::Matrix::new(n, p + 1, x_standardized_data);
348
349 // Center y
350 let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
351 let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
352
353 // Configure lambda path options
354 let options = regularized::path::LambdaPathOptions {
355 nlambda: n_lambda.max(1),
356 lambda_min_ratio: if lambda_min_ratio > 0.0 {
357 Some(lambda_min_ratio)
358 } else {
359 None
360 },
361 alpha: 1.0, // Lasso
362 ..Default::default()
363 };
364
365 let lambda_path =
366 regularized::path::make_lambda_path(&x_standardized, &y_centered, &options, None, Some(0));
367
368 let lambda_max = lambda_path.first().copied().unwrap_or(0.0);
369 let lambda_min = lambda_path.last().copied().unwrap_or(0.0);
370
371 // Return as JSON (note: infinity serializes as null in JSON, handled in JS)
372 let result = serde_json::json!({
373 "lambda_path": lambda_path,
374 "lambda_max": lambda_max,
375 "lambda_min": lambda_min,
376 "n_lambda": lambda_path.len()
377 });
378
379 result.to_string()
380}
381
382/// Helper function to build a design matrix from column vectors.
383///
384/// # Arguments
385///
386/// * `y` - Response variable (used to determine n)
387/// * `x_vars` - Predictor column vectors
388///
389/// # Returns
390///
391/// A tuple of (Matrix, n, p) where p is the number of predictors (excluding intercept)
392fn build_design_matrix(y: &[f64], x_vars: &[Vec<f64>]) -> (linalg::Matrix, usize, usize) {
393 let n = y.len();
394 let p = x_vars.len();
395
396 let mut x_data = vec![1.0; n * (p + 1)]; // Intercept column
397 for (j, x_var) in x_vars.iter().enumerate() {
398 for (i, &val) in x_var.iter().enumerate() {
399 x_data[i * (p + 1) + j + 1] = val;
400 }
401 }
402
403 (linalg::Matrix::new(n, p + 1, x_data), n, p)
404}