use crate::error::{js_serialize_safe, js_string_safe};
use crate::stats_advanced::solve_system_f64;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub fn wasm_matrix_solve(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
if n == 0 || a.len() != n * n || b.len() != n {
return Vec::new();
}
solve_system_f64(a, b, n).unwrap_or_default()
}
#[wasm_bindgen]
pub fn wasm_svd(data: &[f64], rows: usize, cols: usize) -> JsValue {
if rows == 0 || cols == 0 {
return js_string_safe("Error: matrix dimensions must be positive");
}
if data.len() != rows * cols {
return js_string_safe(&format!(
"Error: expected {} elements for {}×{} matrix, got {}",
rows * cols,
rows,
cols,
data.len()
));
}
let (u_flat, s, vt_flat) = bidiag_svd(data, rows, cols);
let k = s.len();
let tol = s[0] * (rows.max(cols) as f64) * f64::EPSILON;
let rank = s.iter().filter(|&&sv| sv > tol).count();
let u_2d: Vec<Vec<f64>> = (0..rows)
.map(|i| u_flat[i * k..(i + 1) * k].to_vec())
.collect();
let vt_2d: Vec<Vec<f64>> = (0..k)
.map(|i| vt_flat[i * cols..(i + 1) * cols].to_vec())
.collect();
let result = serde_json::json!({
"U": u_2d,
"s": s,
"Vt": vt_2d,
"rank": rank,
"rows": rows,
"cols": cols,
});
js_serialize_safe(&result)
}
fn bidiag_svd(data: &[f64], rows: usize, cols: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
let k = rows.min(cols);
let mut a = data.to_vec();
let mut u: Vec<f64> = {
let mut m = vec![0.0_f64; rows * rows];
for i in 0..rows {
m[i * rows + i] = 1.0;
}
m
};
let mut v: Vec<f64> = {
let mut m = vec![0.0_f64; cols * cols];
for i in 0..cols {
m[i * cols + i] = 1.0;
}
m
};
let mut diag = vec![0.0_f64; k];
let mut super_diag = vec![0.0_f64; k.saturating_sub(1)];
for i in 0..k {
{
let mut norm_sq = 0.0_f64;
for r in i..rows {
norm_sq += a[r * cols + i].powi(2);
}
let norm = norm_sq.sqrt();
if norm > 1e-14 {
let sign = if a[i * cols + i] >= 0.0 { 1.0 } else { -1.0 };
let u_norm = norm * sign;
a[i * cols + i] += u_norm;
let beta = 1.0 / (norm * (norm + a[i * cols + i] - u_norm));
for c in (i + 1)..cols {
let mut dot = 0.0_f64;
for r in i..rows {
dot += a[r * cols + i] * a[r * cols + c];
}
dot *= beta;
for r in i..rows {
a[r * cols + c] -= dot * a[r * cols + i];
}
}
let beta_u = 1.0 / (a[i * cols + i].powi(2) / 2.0 + norm_sq - a[i * cols + i].powi(2) / 2.0 + 1e-300);
let _ = beta_u; let factor = a[i * cols + i];
let beta2 = 2.0 / (factor * factor + (i + 1..rows).map(|r| a[r * cols + i].powi(2)).sum::<f64>() + 1e-300);
for c in 0..rows {
let mut dot = 0.0_f64;
for r in i..rows {
dot += u[r * rows + c] * a[r * cols + i];
}
dot *= beta2;
for r in i..rows {
u[r * rows + c] -= dot * a[r * cols + i];
}
}
diag[i] = -u_norm;
a[i * cols + i] = -u_norm;
for r in (i + 1)..rows {
a[r * cols + i] = 0.0;
}
} else {
diag[i] = a[i * cols + i];
}
}
if i + 1 < cols {
let j = i + 1;
let mut norm_sq = 0.0_f64;
for c in j..cols {
norm_sq += a[i * cols + c].powi(2);
}
let norm = norm_sq.sqrt();
if norm > 1e-14 {
let sign = if a[i * cols + j] >= 0.0 { 1.0 } else { -1.0 };
let v_norm = norm * sign;
a[i * cols + j] += v_norm;
let beta = 1.0 / (norm * (norm + a[i * cols + j] - v_norm));
for r in (i + 1)..rows {
let mut dot = 0.0_f64;
for c in j..cols {
dot += a[i * cols + c] * a[r * cols + c];
}
dot *= beta;
for c in j..cols {
a[r * cols + c] -= dot * a[i * cols + c];
}
}
let factor = a[i * cols + j];
let beta2 = 2.0 / (factor * factor + (j + 1..cols).map(|c| a[i * cols + c].powi(2)).sum::<f64>() + 1e-300);
for r in 0..cols {
let mut dot = 0.0_f64;
for c in j..cols {
dot += v[r * cols + c] * a[i * cols + c];
}
dot *= beta2;
for c in j..cols {
v[r * cols + c] -= dot * a[i * cols + c];
}
}
if j < super_diag.len() + 1 && i < super_diag.len() {
super_diag[i] = -v_norm;
}
a[i * cols + j] = -v_norm;
for c in (j + 1)..cols {
a[i * cols + c] = 0.0;
}
} else if i < super_diag.len() {
super_diag[i] = a[i * cols + j];
}
}
}
golub_kahan_svd_step(&mut diag, &mut super_diag, &mut u, &mut v, rows, cols, k);
for i in 0..k {
if diag[i] < 0.0 {
diag[i] = -diag[i];
for r in 0..rows {
u[r * rows + i] = -u[r * rows + i];
}
}
}
let mut order: Vec<usize> = (0..k).collect();
order.sort_by(|&a_idx, &b_idx| {
diag[b_idx]
.partial_cmp(&diag[a_idx])
.unwrap_or(std::cmp::Ordering::Equal)
});
let s_sorted: Vec<f64> = order.iter().map(|&i| diag[i]).collect();
let mut u_thin = vec![0.0_f64; rows * k];
for (new_col, &old_col) in order.iter().enumerate() {
for r in 0..rows {
u_thin[r * k + new_col] = u[r * rows + old_col];
}
}
let mut vt_thin = vec![0.0_f64; k * cols];
for (new_row, &old_col) in order.iter().enumerate() {
for c in 0..cols {
vt_thin[new_row * cols + c] = v[c * cols + old_col];
}
}
(u_thin, s_sorted, vt_thin)
}
fn golub_kahan_svd_step(
diag: &mut Vec<f64>,
super_diag: &mut Vec<f64>,
u: &mut Vec<f64>,
v: &mut Vec<f64>,
rows: usize,
cols: usize,
k: usize,
) {
let max_passes = 30 * k + 6;
let tol = 1e-14_f64;
for _ in 0..max_passes {
let mut converged = true;
for i in 0..super_diag.len() {
if super_diag[i].abs() > tol * (diag[i].abs() + diag[i + 1].abs()) {
converged = false;
} else {
super_diag[i] = 0.0;
}
}
if converged {
break;
}
let q_len = super_diag.len();
let mut q = q_len;
while q > 0 && super_diag[q - 1] == 0.0 {
q -= 1;
}
if q == 0 {
break;
}
let mut p = q;
while p > 0 && diag[p - 1].abs() > tol && super_diag[p - 1] != 0.0 {
p -= 1;
}
let d_q = diag[q];
let d_q1 = if q > 0 { diag[q - 1] } else { 0.0 };
let e = if q > 0 { super_diag[q - 1] } else { 0.0 };
let t11 = d_q1 * d_q1 + (if q >= 2 { super_diag[q - 2].powi(2) } else { 0.0 });
let t12 = d_q1 * e;
let t22 = d_q * d_q + e * e;
let mu = wilkinson_shift(t11, t12, t22);
let mut y = diag[p] * diag[p] - mu;
let mut z = diag[p] * super_diag[p].min(0.0 + if p < super_diag.len() { super_diag[p] } else { 0.0 });
for i in p..q {
let (cos_r, sin_r) = givens_rotation(y, z);
apply_givens_right_bidiag(diag, super_diag, i, cos_r, sin_r, q_len);
for r in 0..cols {
let a = v[r * cols + i];
let b = v[r * cols + i + 1];
v[r * cols + i] = cos_r * a + sin_r * b;
v[r * cols + i + 1] = -sin_r * a + cos_r * b;
}
y = diag[i];
z = super_diag[i];
let (cos_l, sin_l) = givens_rotation(y, z);
apply_givens_left_bidiag(diag, super_diag, i, cos_l, sin_l, q_len);
for r in 0..rows {
let a = u[r * rows + i];
let b = u[r * rows + i + 1];
u[r * rows + i] = cos_l * a + sin_l * b;
u[r * rows + i + 1] = -sin_l * a + cos_l * b;
}
if i + 1 < q {
y = diag[i + 1];
z = if i + 1 < super_diag.len() { super_diag[i + 1] } else { 0.0 };
}
}
}
}
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
if b == 0.0 {
let c = if a >= 0.0 { 1.0 } else { -1.0 };
return (c, 0.0);
}
if a == 0.0 {
let s = if b >= 0.0 { 1.0 } else { -1.0 };
return (0.0, s);
}
let r = a.hypot(b);
(a / r, b / r)
}
fn apply_givens_right_bidiag(
diag: &mut Vec<f64>,
super_diag: &mut Vec<f64>,
i: usize,
c: f64,
s: f64,
_q_len: usize,
) {
let d_i = diag[i];
let e_i = if i < super_diag.len() { super_diag[i] } else { 0.0 };
let new_d = c * d_i + s * e_i;
let new_e = -s * d_i + c * e_i;
diag[i] = new_d;
if i < super_diag.len() {
super_diag[i] = new_e;
}
if i + 1 < diag.len() {
let d_next = diag[i + 1];
let old_e_minus1 = if i > 0 { super_diag[i - 1] } else { 0.0 };
if i > 0 {
super_diag[i - 1] = c * old_e_minus1;
}
let _ = d_next; }
}
fn apply_givens_left_bidiag(
diag: &mut Vec<f64>,
super_diag: &mut Vec<f64>,
i: usize,
c: f64,
s: f64,
q_len: usize,
) {
let d_i = diag[i];
let e_i = if i < super_diag.len() { super_diag[i] } else { 0.0 };
let new_d = c * d_i + s * e_i;
let new_e = -s * d_i + c * e_i;
diag[i] = new_d;
if i < super_diag.len() {
super_diag[i] = new_e;
}
if i + 1 < diag.len() {
let d_next = diag[i + 1];
diag[i + 1] = c * d_next; let _ = q_len;
}
}
fn wilkinson_shift(t11: f64, t12: f64, t22: f64) -> f64 {
let delta = (t11 - t22) / 2.0;
if delta == 0.0 {
return t22 - t12.abs();
}
let sign = if delta >= 0.0 { 1.0 } else { -1.0 };
t22 - sign * t12 * t12 / (delta.abs() + (delta * delta + t12 * t12).sqrt())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_wasm_matrix_solve_basic() {
let a = [2.0_f64, 1.0, 1.0, 3.0];
let b = [5.0_f64, 10.0];
let x = wasm_matrix_solve(&a, &b, 2);
assert_eq!(x.len(), 2);
assert!((x[0] - 1.0).abs() < 1e-10, "x[0] = {}", x[0]);
assert!((x[1] - 3.0).abs() < 1e-10, "x[1] = {}", x[1]);
}
#[test]
fn test_wasm_matrix_solve_singular() {
let a = [1.0_f64, 2.0, 2.0, 4.0];
let b = [3.0_f64, 6.0];
let x = wasm_matrix_solve(&a, &b, 2);
let _ = x;
}
#[test]
fn test_wasm_matrix_solve_wrong_size() {
let a = [1.0_f64, 0.0, 0.0];
let b = [1.0_f64, 1.0];
let x = wasm_matrix_solve(&a, &b, 2);
assert!(x.is_empty());
}
#[cfg(target_arch = "wasm32")]
#[test]
fn test_wasm_svd_identity() {
let data = [1.0_f64, 0.0, 0.0, 1.0];
let result = wasm_svd(&data, 2, 2);
assert!(!result.is_null(), "SVD returned NULL");
assert!(!result.is_string(), "SVD returned error: {:?}", result);
}
#[test]
fn test_wasm_svd_wrong_size() {
let data = [1.0_f64, 2.0, 3.0];
let result = wasm_svd(&data, 2, 2); #[cfg(target_arch = "wasm32")]
assert!(result.is_string());
#[cfg(not(target_arch = "wasm32"))]
assert!(result.is_null());
}
#[test]
fn test_wasm_svd_zero_dimensions() {
let result = wasm_svd(&[], 0, 2);
#[cfg(target_arch = "wasm32")]
assert!(result.is_string());
#[cfg(not(target_arch = "wasm32"))]
assert!(result.is_null());
}
#[test]
fn test_givens_rotation_zero_b() {
let (c, s) = givens_rotation(3.0, 0.0);
assert!((c - 1.0).abs() < 1e-12);
assert!((s).abs() < 1e-12);
}
#[test]
fn test_givens_rotation_zero_a() {
let (c, s) = givens_rotation(0.0, 4.0);
assert!((c).abs() < 1e-12);
assert!((s - 1.0).abs() < 1e-12);
}
#[test]
fn test_wilkinson_shift_symmetric() {
let mu = wilkinson_shift(5.0, 2.0, 5.0);
assert!((mu - 3.0).abs() < 1e-10, "mu = {}", mu);
}
}