1#![cfg(feature = "wasm")]
6
7use wasm_bindgen::prelude::*;
8
9use super::domain::check_domain;
10use crate::error::{error_json, error_to_json};
11use crate::linalg;
12use crate::regularized;
13
14#[wasm_bindgen]
42pub fn ridge_regression(
43 y_json: &str,
44 x_vars_json: &str,
45 _variable_names: &str,
46 lambda: f64,
47 standardize: bool,
48) -> String {
49 if let Err(e) = check_domain() {
50 return error_to_json(&e);
51 }
52
53 let y: Vec<f64> = match serde_json::from_str(y_json) {
55 Ok(v) => v,
56 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
57 };
58
59 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
60 Ok(v) => v,
61 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
62 };
63
64 let (x, n, p) = build_design_matrix(&y, &x_vars);
66
67 if n <= p + 1 {
68 return error_json(&format!(
69 "Insufficient data: need at least {} observations for {} predictors",
70 p + 2,
71 p
72 ));
73 }
74
75 let options = regularized::ridge::RidgeFitOptions {
77 lambda,
78 intercept: true,
79 standardize,
80 max_iter: 100000,
81 tol: 1e-7,
82 warm_start: None,
83 weights: None,
84 };
85
86 match regularized::ridge::ridge_fit(&x, &y, &options) {
87 Ok(output) => serde_json::to_string(&output)
88 .unwrap_or_else(|_| error_json("Failed to serialize ridge regression result")),
89 Err(e) => error_json(&e.to_string()),
90 }
91}
92
93#[wasm_bindgen]
126pub fn lasso_regression(
127 y_json: &str,
128 x_vars_json: &str,
129 _variable_names: &str,
130 lambda: f64,
131 standardize: bool,
132 max_iter: usize,
133 tol: f64,
134) -> String {
135 if let Err(e) = check_domain() {
136 return error_to_json(&e);
137 }
138
139 let y: Vec<f64> = match serde_json::from_str(y_json) {
141 Ok(v) => v,
142 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
143 };
144
145 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
146 Ok(v) => v,
147 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
148 };
149
150 let (x, n, p) = build_design_matrix(&y, &x_vars);
152
153 if n <= p + 1 {
154 return error_json(&format!(
155 "Insufficient data: need at least {} observations for {} predictors",
156 p + 2,
157 p
158 ));
159 }
160
161 let options = regularized::lasso::LassoFitOptions {
163 lambda,
164 intercept: true,
165 standardize,
166 max_iter,
167 tol,
168 ..Default::default()
169 };
170
171 match regularized::lasso::lasso_fit(&x, &y, &options) {
172 Ok(output) => serde_json::to_string(&output)
173 .unwrap_or_else(|_| error_json("Failed to serialize lasso regression result")),
174 Err(e) => error_json(&e.to_string()),
175 }
176}
177
178#[wasm_bindgen]
202#[allow(clippy::too_many_arguments)]
203pub fn elastic_net_regression(
204 y_json: &str,
205 x_vars_json: &str,
206 _variable_names: &str,
207 lambda: f64,
208 alpha: f64,
209 standardize: bool,
210 max_iter: usize,
211 tol: f64,
212) -> String {
213 if let Err(e) = check_domain() {
214 return error_to_json(&e);
215 }
216
217 let y: Vec<f64> = match serde_json::from_str(y_json) {
219 Ok(v) => v,
220 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
221 };
222
223 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
224 Ok(v) => v,
225 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
226 };
227
228 let (x, n, p) = build_design_matrix(&y, &x_vars);
230
231 if n <= p + 1 {
232 return error_json(&format!(
233 "Insufficient data: need at least {} observations for {} predictors",
234 p + 2,
235 p
236 ));
237 }
238
239 let options = regularized::elastic_net::ElasticNetOptions {
241 lambda,
242 alpha,
243 intercept: true,
244 standardize,
245 max_iter,
246 tol,
247 ..Default::default()
248 };
249
250 match regularized::elastic_net::elastic_net_fit(&x, &y, &options) {
251 Ok(output) => serde_json::to_string(&output)
252 .unwrap_or_else(|_| error_json("Failed to serialize elastic net regression result")),
253 Err(e) => error_json(&e.to_string()),
254 }
255}
256
/// JSON payload returned by `elastic_net_path_wasm`.
///
/// Each vector holds one entry per fit on the lambda path, in the order the
/// fits were returned by `elastic_net_path`. Index `k` of every field therefore
/// describes the same fit. The field names and their declaration order
/// determine the serialized JSON keys and key order, so treat them as part of
/// the public output format.
#[derive(serde::Serialize)]
struct PathResult {
    /// Penalty value of each fit (copied from the fit's `lambda`).
    lambdas: Vec<f64>,
    /// Coefficient vector of each fit (copied from the fit's `coefficients`).
    coefficients: Vec<Vec<f64>>,
    /// R² of each fit (copied from the fit's `r_squared`).
    r_squared: Vec<f64>,
    /// AIC of each fit (copied from the fit's `aic`).
    aic: Vec<f64>,
    /// BIC of each fit (copied from the fit's `bic`).
    bic: Vec<f64>,
    /// Copied from the fit's `n_nonzero`; presumably the count of non-zero
    /// coefficients. Confirm against the `regularized` fit type.
    n_nonzero: Vec<usize>,
}
267
268#[wasm_bindgen]
288#[allow(clippy::too_many_arguments)]
289pub fn elastic_net_path_wasm(
290 y_json: &str,
291 x_vars_json: &str,
292 n_lambda: usize,
293 lambda_min_ratio: f64,
294 alpha: f64,
295 standardize: bool,
296 max_iter: usize,
297 tol: f64,
298) -> String {
299 if let Err(e) = check_domain() {
300 return error_to_json(&e);
301 }
302
303 let y: Vec<f64> = match serde_json::from_str(y_json) {
305 Ok(v) => v,
306 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
307 };
308
309 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
310 Ok(v) => v,
311 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
312 };
313
314 let (x, n, p) = build_design_matrix(&y, &x_vars);
316
317 if n <= p + 1 {
318 return error_json(&format!(
319 "Insufficient data: need at least {} observations for {} predictors",
320 p + 2,
321 p
322 ));
323 }
324
325 let path_options = regularized::path::LambdaPathOptions {
327 nlambda: n_lambda.max(1),
328 lambda_min_ratio: if lambda_min_ratio > 0.0 {
329 Some(lambda_min_ratio)
330 } else {
331 None
332 },
333 alpha,
334 ..Default::default()
335 };
336
337 let fit_options = regularized::elastic_net::ElasticNetOptions {
338 lambda: 0.0, alpha,
340 intercept: true,
341 standardize,
342 max_iter,
343 tol,
344 ..Default::default()
345 };
346
347 match regularized::elastic_net::elastic_net_path(&x, &y, &path_options, &fit_options) {
348 Ok(fits) => {
349 let result = PathResult {
351 lambdas: fits.iter().map(|f| f.lambda).collect(),
352 coefficients: fits.iter().map(|f| f.coefficients.clone()).collect(),
353 r_squared: fits.iter().map(|f| f.r_squared).collect(),
354 aic: fits.iter().map(|f| f.aic).collect(),
355 bic: fits.iter().map(|f| f.bic).collect(),
356 n_nonzero: fits.iter().map(|f| f.n_nonzero).collect(),
357 };
358
359 serde_json::to_string(&result)
360 .unwrap_or_else(|_| error_json("Failed to serialize elastic net path result"))
361 },
362 Err(e) => error_json(&e.to_string()),
363 }
364}
365
366#[wasm_bindgen]
392pub fn make_lambda_path(
393 y_json: &str,
394 x_vars_json: &str,
395 n_lambda: usize,
396 lambda_min_ratio: f64,
397) -> String {
398 if let Err(e) = check_domain() {
399 return error_to_json(&e);
400 }
401
402 let y: Vec<f64> = match serde_json::from_str(y_json) {
404 Ok(v) => v,
405 Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
406 };
407
408 let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
409 Ok(v) => v,
410 Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
411 };
412
413 let (x, n, p) = build_design_matrix(&y, &x_vars);
415
416 let x_mean: Vec<f64> = (0..x.cols)
418 .map(|j| {
419 if j == 0 {
420 1.0 } else {
422 (0..n).map(|i| x.get(i, j)).sum::<f64>() / n as f64
423 }
424 })
425 .collect();
426
427 let x_standardized: Vec<f64> = (0..x.cols)
428 .map(|j| {
429 if j == 0 {
430 0.0 } else {
432 let mean = x_mean[j];
433 let variance =
434 (0..n).map(|i| (x.get(i, j) - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
435 variance.sqrt()
436 }
437 })
438 .collect();
439
440 let mut x_standardized_data = vec![1.0; n * (p + 1)];
442 for j in 0..x.cols {
443 for i in 0..n {
444 if j == 0 {
445 x_standardized_data[i * (p + 1)] = 1.0; } else {
447 let std = x_standardized[j];
448 if std > 1e-10 {
449 x_standardized_data[i * (p + 1) + j] = (x.get(i, j) - x_mean[j]) / std;
450 } else {
451 x_standardized_data[i * (p + 1) + j] = 0.0;
452 }
453 }
454 }
455 }
456 let x_standardized = linalg::Matrix::new(n, p + 1, x_standardized_data);
457
458 let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
460 let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
461
462 let options = regularized::path::LambdaPathOptions {
464 nlambda: n_lambda.max(1),
465 lambda_min_ratio: if lambda_min_ratio > 0.0 {
466 Some(lambda_min_ratio)
467 } else {
468 None
469 },
470 alpha: 1.0, ..Default::default()
472 };
473
474 let lambda_path =
475 regularized::path::make_lambda_path(&x_standardized, &y_centered, &options, None, Some(0));
476
477 let lambda_max = lambda_path.first().copied().unwrap_or(0.0);
478 let lambda_min = lambda_path.last().copied().unwrap_or(0.0);
479
480 let result = serde_json::json!({
482 "lambda_path": lambda_path,
483 "lambda_max": lambda_max,
484 "lambda_min": lambda_min,
485 "n_lambda": lambda_path.len()
486 });
487
488 result.to_string()
489}
490
491fn build_design_matrix(y: &[f64], x_vars: &[Vec<f64>]) -> (linalg::Matrix, usize, usize) {
502 let n = y.len();
503 let p = x_vars.len();
504
505 let mut x_data = vec![1.0; n * (p + 1)]; for (j, x_var) in x_vars.iter().enumerate() {
507 for (i, &val) in x_var.iter().enumerate() {
508 x_data[i * (p + 1) + j + 1] = val;
509 }
510 }
511
512 (linalg::Matrix::new(n, p + 1, x_data), n, p)
513}