#![cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;
use super::domain::check_domain;
use crate::error::{error_json, error_to_json};
use crate::linalg;
use crate::regularized;
/// Fits a ridge (L2-penalized) linear regression and returns the result as a JSON string.
///
/// # Arguments
///
/// * `y_json` - JSON array of response values.
/// * `x_vars_json` - JSON array of predictor columns; each column must hold exactly one value
///   per entry of `y`.
/// * `_variable_names` - accepted for interface compatibility; not used.
/// * `lambda` - penalty strength, forwarded to `RidgeFitOptions::lambda`.
/// * `standardize` - forwarded to `RidgeFitOptions::standardize`.
///
/// An intercept is always fitted, and the solver is configured with `max_iter = 100000` and
/// `tol = 1e-7`.
///
/// # Returns
///
/// The serialized ridge fit, or a JSON error object when the domain check fails, an input
/// cannot be parsed, a predictor column's length differs from `y`, there are not more than
/// `p + 1` observations, or the solver/serializer reports an error.
#[wasm_bindgen]
pub fn ridge_regression(
    y_json: &str,
    x_vars_json: &str,
    _variable_names: &str,
    lambda: f64,
    standardize: bool,
) -> String {
    if let Err(e) = check_domain() {
        return error_to_json(&e);
    }
    let y: Vec<f64> = match serde_json::from_str(y_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
    };
    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
    };
    // Each predictor needs exactly one value per observation. Rejecting mismatches here returns
    // a JSON error instead of building a malformed design matrix.
    if let Some((j, col)) = x_vars.iter().enumerate().find(|(_, col)| col.len() != y.len()) {
        return error_json(&format!(
            "x_vars[{}] has {} values but y has {}",
            j,
            col.len(),
            y.len()
        ));
    }
    let (x, n, p) = build_design_matrix(&y, &x_vars);
    // `p` slopes plus an intercept: require at least one residual degree of freedom.
    if n <= p + 1 {
        return error_json(&format!(
            "Insufficient data: need at least {} observations for {} predictors",
            p + 2,
            p
        ));
    }
    let options = regularized::ridge::RidgeFitOptions {
        lambda,
        intercept: true,
        standardize,
        max_iter: 100000,
        tol: 1e-7,
        warm_start: None,
        weights: None,
    };
    match regularized::ridge::ridge_fit(&x, &y, &options) {
        Ok(output) => serde_json::to_string(&output)
            .unwrap_or_else(|_| error_json("Failed to serialize ridge regression result")),
        Err(e) => error_json(&e.to_string()),
    }
}
/// Fits a lasso (L1-penalized) linear regression and returns the result as a JSON string.
///
/// # Arguments
///
/// * `y_json` - JSON array of response values.
/// * `x_vars_json` - JSON array of predictor columns; each column must hold exactly one value
///   per entry of `y`.
/// * `_variable_names` - accepted for interface compatibility; not used.
/// * `lambda` - penalty strength, forwarded to `LassoFitOptions::lambda`.
/// * `standardize` - forwarded to `LassoFitOptions::standardize`.
/// * `max_iter` - iteration cap forwarded to the solver.
/// * `tol` - convergence tolerance forwarded to the solver.
///
/// An intercept is always fitted. All other `LassoFitOptions` fields use their defaults.
///
/// # Returns
///
/// The serialized lasso fit, or a JSON error object when the domain check fails, an input
/// cannot be parsed, a predictor column's length differs from `y`, there are not more than
/// `p + 1` observations, or the solver/serializer reports an error.
#[wasm_bindgen]
pub fn lasso_regression(
    y_json: &str,
    x_vars_json: &str,
    _variable_names: &str,
    lambda: f64,
    standardize: bool,
    max_iter: usize,
    tol: f64,
) -> String {
    if let Err(e) = check_domain() {
        return error_to_json(&e);
    }
    let y: Vec<f64> = match serde_json::from_str(y_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
    };
    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
    };
    // Each predictor needs exactly one value per observation. Rejecting mismatches here returns
    // a JSON error instead of building a malformed design matrix.
    if let Some((j, col)) = x_vars.iter().enumerate().find(|(_, col)| col.len() != y.len()) {
        return error_json(&format!(
            "x_vars[{}] has {} values but y has {}",
            j,
            col.len(),
            y.len()
        ));
    }
    let (x, n, p) = build_design_matrix(&y, &x_vars);
    // `p` slopes plus an intercept: require at least one residual degree of freedom.
    if n <= p + 1 {
        return error_json(&format!(
            "Insufficient data: need at least {} observations for {} predictors",
            p + 2,
            p
        ));
    }
    let options = regularized::lasso::LassoFitOptions {
        lambda,
        intercept: true,
        standardize,
        max_iter,
        tol,
        ..Default::default()
    };
    match regularized::lasso::lasso_fit(&x, &y, &options) {
        Ok(output) => serde_json::to_string(&output)
            .unwrap_or_else(|_| error_json("Failed to serialize lasso regression result")),
        Err(e) => error_json(&e.to_string()),
    }
}
/// Fits an elastic-net penalized linear regression and returns the result as a JSON string.
///
/// # Arguments
///
/// * `y_json` - JSON array of response values.
/// * `x_vars_json` - JSON array of predictor columns; each column must hold exactly one value
///   per entry of `y`.
/// * `_variable_names` - accepted for interface compatibility; not used.
/// * `lambda` - penalty strength, forwarded to `ElasticNetOptions::lambda`.
/// * `alpha` - forwarded to `ElasticNetOptions::alpha` (presumably the L1/L2 mixing weight;
///   confirm against `regularized::elastic_net`).
/// * `standardize` - forwarded to `ElasticNetOptions::standardize`.
/// * `max_iter` - iteration cap forwarded to the solver.
/// * `tol` - convergence tolerance forwarded to the solver.
///
/// An intercept is always fitted. All other `ElasticNetOptions` fields use their defaults.
///
/// # Returns
///
/// The serialized elastic-net fit, or a JSON error object when the domain check fails, an
/// input cannot be parsed, a predictor column's length differs from `y`, there are not more
/// than `p + 1` observations, or the solver/serializer reports an error.
#[wasm_bindgen]
// Mirrors the JS-facing signature; bundling the scalars into a struct would change the wasm API.
#[allow(clippy::too_many_arguments)]
pub fn elastic_net_regression(
    y_json: &str,
    x_vars_json: &str,
    _variable_names: &str,
    lambda: f64,
    alpha: f64,
    standardize: bool,
    max_iter: usize,
    tol: f64,
) -> String {
    if let Err(e) = check_domain() {
        return error_to_json(&e);
    }
    let y: Vec<f64> = match serde_json::from_str(y_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
    };
    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
    };
    // Each predictor needs exactly one value per observation. Rejecting mismatches here returns
    // a JSON error instead of building a malformed design matrix.
    if let Some((j, col)) = x_vars.iter().enumerate().find(|(_, col)| col.len() != y.len()) {
        return error_json(&format!(
            "x_vars[{}] has {} values but y has {}",
            j,
            col.len(),
            y.len()
        ));
    }
    let (x, n, p) = build_design_matrix(&y, &x_vars);
    // `p` slopes plus an intercept: require at least one residual degree of freedom.
    if n <= p + 1 {
        return error_json(&format!(
            "Insufficient data: need at least {} observations for {} predictors",
            p + 2,
            p
        ));
    }
    let options = regularized::elastic_net::ElasticNetOptions {
        lambda,
        alpha,
        intercept: true,
        standardize,
        max_iter,
        tol,
        ..Default::default()
    };
    match regularized::elastic_net::elastic_net_fit(&x, &y, &options) {
        Ok(output) => serde_json::to_string(&output)
            .unwrap_or_else(|_| error_json("Failed to serialize elastic net regression result")),
        Err(e) => error_json(&e.to_string()),
    }
}
/// JSON payload produced by `elastic_net_path_wasm`.
///
/// Every vector is built from the same sequence of path fits, so element `k` of each field
/// describes the fit at `lambdas[k]`. Field declaration order is also the key order in the
/// serialized JSON.
#[derive(serde::Serialize)]
struct PathResult {
    /// Lambda value of each fit along the path, in the order the path solver returned them.
    lambdas: Vec<f64>,
    /// Coefficient vector of each fit (layout as produced by the elastic-net solver).
    coefficients: Vec<Vec<f64>>,
    /// R-squared of each fit.
    r_squared: Vec<f64>,
    /// AIC of each fit.
    aic: Vec<f64>,
    /// BIC of each fit.
    bic: Vec<f64>,
    /// Number of non-zero coefficients reported for each fit.
    n_nonzero: Vec<usize>,
}
/// Fits an elastic-net model along a sequence of lambda values and returns the path as JSON.
///
/// # Arguments
///
/// * `y_json` - JSON array of response values.
/// * `x_vars_json` - JSON array of predictor columns; each column must hold exactly one value
///   per entry of `y`.
/// * `n_lambda` - number of lambda values requested; `0` is treated as `1`.
/// * `lambda_min_ratio` - used only when strictly positive. Otherwise (zero, negative or NaN)
///   the path module's default is used.
/// * `alpha` - forwarded to both the path options and the fit options.
/// * `standardize` - forwarded to `ElasticNetOptions::standardize`.
/// * `max_iter` - iteration cap forwarded to the solver.
/// * `tol` - convergence tolerance forwarded to the solver.
///
/// # Returns
///
/// A serialized `PathResult` (one entry per fitted lambda), or a JSON error object when the
/// domain check fails, an input cannot be parsed, a predictor column's length differs from
/// `y`, there are not more than `p + 1` observations, or the solver/serializer fails.
#[wasm_bindgen]
// Mirrors the JS-facing signature; bundling the scalars into a struct would change the wasm API.
#[allow(clippy::too_many_arguments)]
pub fn elastic_net_path_wasm(
    y_json: &str,
    x_vars_json: &str,
    n_lambda: usize,
    lambda_min_ratio: f64,
    alpha: f64,
    standardize: bool,
    max_iter: usize,
    tol: f64,
) -> String {
    if let Err(e) = check_domain() {
        return error_to_json(&e);
    }
    let y: Vec<f64> = match serde_json::from_str(y_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
    };
    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
    };
    // Each predictor needs exactly one value per observation. Rejecting mismatches here returns
    // a JSON error instead of building a malformed design matrix.
    if let Some((j, col)) = x_vars.iter().enumerate().find(|(_, col)| col.len() != y.len()) {
        return error_json(&format!(
            "x_vars[{}] has {} values but y has {}",
            j,
            col.len(),
            y.len()
        ));
    }
    let (x, n, p) = build_design_matrix(&y, &x_vars);
    // `p` slopes plus an intercept: require at least one residual degree of freedom.
    if n <= p + 1 {
        return error_json(&format!(
            "Insufficient data: need at least {} observations for {} predictors",
            p + 2,
            p
        ));
    }
    let path_options = regularized::path::LambdaPathOptions {
        nlambda: n_lambda.max(1),
        lambda_min_ratio: if lambda_min_ratio > 0.0 {
            Some(lambda_min_ratio)
        } else {
            None
        },
        alpha,
        ..Default::default()
    };
    let fit_options = regularized::elastic_net::ElasticNetOptions {
        // Placeholder: presumably replaced per path point by `elastic_net_path` (confirm there).
        lambda: 0.0,
        alpha,
        intercept: true,
        standardize,
        max_iter,
        tol,
        ..Default::default()
    };
    match regularized::elastic_net::elastic_net_path(&x, &y, &path_options, &fit_options) {
        Ok(fits) => {
            let result = PathResult {
                lambdas: fits.iter().map(|f| f.lambda).collect(),
                coefficients: fits.iter().map(|f| f.coefficients.clone()).collect(),
                r_squared: fits.iter().map(|f| f.r_squared).collect(),
                aic: fits.iter().map(|f| f.aic).collect(),
                bic: fits.iter().map(|f| f.bic).collect(),
                n_nonzero: fits.iter().map(|f| f.n_nonzero).collect(),
            };
            serde_json::to_string(&result)
                .unwrap_or_else(|_| error_json("Failed to serialize elastic net path result"))
        }
        Err(e) => error_json(&e.to_string()),
    }
}
/// Computes a regularization (lambda) path for the given data and returns it as JSON.
///
/// Each predictor column is centered and divided by its sample standard deviation (divisor
/// `n - 1`); near-constant columns (std-dev <= 1e-10) are set to zero. `y` is centered. The
/// standardized data is passed to `regularized::path::make_lambda_path` with `alpha = 1.0`.
///
/// # Arguments
///
/// * `y_json` - JSON array of response values.
/// * `x_vars_json` - JSON array of predictor columns; each column must hold exactly one value
///   per entry of `y`.
/// * `n_lambda` - number of lambda values requested; `0` is treated as `1`.
/// * `lambda_min_ratio` - used only when strictly positive. Otherwise (zero, negative or NaN)
///   the path module's default is used.
///
/// # Returns
///
/// `{"lambda_path": [...], "lambda_max": f64, "lambda_min": f64, "n_lambda": usize}`, where
/// `lambda_max`/`lambda_min` are the first/last path entries (0.0 for an empty path). A JSON
/// error object is returned instead when the domain check fails, an input cannot be parsed, a
/// predictor column's length differs from `y`, or there are fewer than 2 observations.
#[wasm_bindgen]
pub fn make_lambda_path(
    y_json: &str,
    x_vars_json: &str,
    n_lambda: usize,
    lambda_min_ratio: f64,
) -> String {
    if let Err(e) = check_domain() {
        return error_to_json(&e);
    }
    let y: Vec<f64> = match serde_json::from_str(y_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse y: {}", e)),
    };
    let x_vars: Vec<Vec<f64>> = match serde_json::from_str(x_vars_json) {
        Ok(v) => v,
        Err(e) => return error_json(&format!("Failed to parse x_vars: {}", e)),
    };
    // Each predictor needs exactly one value per observation. Rejecting mismatches here returns
    // a JSON error instead of building a malformed design matrix.
    if let Some((j, col)) = x_vars.iter().enumerate().find(|(_, col)| col.len() != y.len()) {
        return error_json(&format!(
            "x_vars[{}] has {} values but y has {}",
            j,
            col.len(),
            y.len()
        ));
    }
    let (x, n, p) = build_design_matrix(&y, &x_vars);
    // The scaling below divides by `n - 1`. With one observation that is 0/0, and with none the
    // `usize` subtraction itself underflows (panic in debug, wrap-around in release).
    if n < 2 {
        return error_json("Insufficient data: need at least 2 observations to compute a lambda path");
    }

    let cols = p + 1;
    // Column 0 is the intercept and stays all ones; columns 1..=p are centered and scaled.
    let mut standardized_data = vec![1.0; n * cols];
    for j in 1..cols {
        let mean = (0..n).map(|i| x.get(i, j)).sum::<f64>() / n as f64;
        let variance =
            (0..n).map(|i| (x.get(i, j) - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
        let std_dev = variance.sqrt();
        for i in 0..n {
            // Near-constant columns carry no signal; zero them rather than divide by ~0.
            standardized_data[i * cols + j] = if std_dev > 1e-10 {
                (x.get(i, j) - mean) / std_dev
            } else {
                0.0
            };
        }
    }
    let x_standardized = linalg::Matrix::new(n, cols, standardized_data);

    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();

    let options = regularized::path::LambdaPathOptions {
        nlambda: n_lambda.max(1),
        lambda_min_ratio: if lambda_min_ratio > 0.0 {
            Some(lambda_min_ratio)
        } else {
            None
        },
        alpha: 1.0,
        ..Default::default()
    };
    // NOTE(review): the trailing `None, Some(0)` arguments presumably mean "no penalty factors"
    // and "column 0 is the intercept"; confirm against `regularized::path::make_lambda_path`.
    let lambda_path =
        regularized::path::make_lambda_path(&x_standardized, &y_centered, &options, None, Some(0));
    let lambda_max = lambda_path.first().copied().unwrap_or(0.0);
    let lambda_min = lambda_path.last().copied().unwrap_or(0.0);
    serde_json::json!({
        "lambda_path": lambda_path,
        "lambda_max": lambda_max,
        "lambda_min": lambda_min,
        "n_lambda": lambda_path.len()
    })
    .to_string()
}
/// Assembles a row-major `n x (p + 1)` design matrix whose first column is an intercept of ones.
///
/// `n` is `y.len()` (only the length of `y` is used) and `p` is `x_vars.len()`; column `j + 1`
/// holds the values of `x_vars[j]`.
///
/// Returns `(matrix, n, p)`.
///
/// Callers are expected to verify that every predictor column has exactly `n` values. As a
/// safety net, values past row `n` are ignored instead of indexing out of bounds, because a
/// panic in wasm traps instead of producing a JSON error. Rows missing from a shorter column
/// keep the initial fill value `1.0`.
fn build_design_matrix(y: &[f64], x_vars: &[Vec<f64>]) -> (linalg::Matrix, usize, usize) {
    let n = y.len();
    let p = x_vars.len();
    let cols = p + 1;
    // Pre-filled with ones so column 0 is already the intercept.
    let mut x_data = vec![1.0; n * cols];
    for (j, x_var) in x_vars.iter().enumerate() {
        for (i, &val) in x_var.iter().take(n).enumerate() {
            x_data[i * cols + j + 1] = val;
        }
    }
    (linalg::Matrix::new(n, cols, x_data), n, p)
}