use crate::error::WasmError;
use wasm_bindgen::prelude::*;
/// Nelder–Mead simplex minimization driven entirely by caller-supplied data.
///
/// `f_values` holds one objective value per simplex vertex; `x0` holds the
/// vertex coordinates flattened row-major (`n_vertices * n_dim` entries with
/// `n_dim = n_vertices - 1`). Because WASM cannot call back into the JS
/// objective per trial point, candidate points are scored with the
/// inverse-distance-weighted surrogate in `evaluate_quadratic_model` built
/// from the current simplex, so the result approximates a Nelder–Mead step
/// sequence on that surrogate rather than the true objective.
///
/// Returns a JS object `{ x, fun, nit, success, simplex, simplex_values }`.
///
/// # Errors
/// Rejects fewer than 2 vertices, a coordinate buffer whose length does not
/// equal `n_vertices * n_dim`, or a non-positive tolerance.
#[wasm_bindgen]
pub fn minimize_nelder_mead(
    f_values: &[f64],
    x0: &[f64],
    tol: f64,
    max_iter: u32,
) -> Result<JsValue, JsValue> {
    let n_vertices = f_values.len();
    if n_vertices < 2 {
        return Err(WasmError::InvalidParameter(
            "Need at least 2 vertices (function values) for Nelder-Mead".to_string(),
        )
        .into());
    }
    // A non-degenerate simplex in n dimensions has n + 1 vertices.
    let n_dim = n_vertices - 1;
    let expected_coords = n_vertices * n_dim;
    if x0.len() != expected_coords {
        return Err(WasmError::InvalidParameter(format!(
            "Expected {} coordinates ({} vertices x {} dimensions), got {}",
            expected_coords,
            n_vertices,
            n_dim,
            x0.len()
        ))
        .into());
    }
    if tol <= 0.0 {
        return Err(WasmError::InvalidParameter("Tolerance must be positive".to_string()).into());
    }
    // Unflatten the coordinate buffer into one Vec per vertex.
    let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n_vertices);
    for i in 0..n_vertices {
        let start = i * n_dim;
        simplex.push(x0[start..start + n_dim].to_vec());
    }
    let mut values: Vec<f64> = f_values.to_vec();
    // Standard Nelder–Mead coefficients: reflection, expansion, contraction, shrink.
    let alpha = 1.0;
    let gamma = 2.0;
    let rho = 0.5;
    let sigma = 0.5;
    let max_iter = max_iter as usize;
    let mut nit: usize = 0;
    let mut success = false;
    for _iter in 0..max_iter {
        nit += 1;
        // Sort vertices so values[0] is the best and the last is the worst.
        let mut indices: Vec<usize> = (0..n_vertices).collect();
        indices.sort_by(|&a, &b| {
            values[a]
                .partial_cmp(&values[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let sorted_simplex: Vec<Vec<f64>> = indices.iter().map(|&i| simplex[i].clone()).collect();
        let sorted_values: Vec<f64> = indices.iter().map(|&i| values[i]).collect();
        simplex = sorted_simplex;
        values = sorted_values;
        // Convergence test: standard deviation of the vertex values.
        let mean_val = values.iter().sum::<f64>() / n_vertices as f64;
        let std_dev =
            (values.iter().map(|v| (v - mean_val).powi(2)).sum::<f64>() / n_vertices as f64).sqrt();
        if std_dev < tol {
            success = true;
            break;
        }
        // Centroid of all vertices except the worst.
        let centroid = compute_centroid(&simplex[..n_vertices - 1], n_dim);
        // Reflection step. (Fix: the source had mojibake `¢roid` — the
        // HTML entity rendering of `&centroid` — which does not compile.)
        let reflected = reflect(&centroid, &simplex[n_vertices - 1], alpha, n_dim);
        let f_reflected = evaluate_quadratic_model(&reflected, &simplex, &values);
        if f_reflected < values[n_vertices - 2] && f_reflected >= values[0] {
            simplex[n_vertices - 1] = reflected;
            values[n_vertices - 1] = f_reflected;
            continue;
        }
        // Expansion step: the reflected point beat the current best.
        if f_reflected < values[0] {
            let expanded = reflect(&centroid, &simplex[n_vertices - 1], gamma, n_dim);
            let f_expanded = evaluate_quadratic_model(&expanded, &simplex, &values);
            if f_expanded < f_reflected {
                simplex[n_vertices - 1] = expanded;
                values[n_vertices - 1] = f_expanded;
            } else {
                simplex[n_vertices - 1] = reflected;
                values[n_vertices - 1] = f_reflected;
            }
            continue;
        }
        // Contraction step.
        let contracted = reflect(&centroid, &simplex[n_vertices - 1], rho, n_dim);
        let f_contracted = evaluate_quadratic_model(&contracted, &simplex, &values);
        if f_contracted < values[n_vertices - 1] {
            simplex[n_vertices - 1] = contracted;
            values[n_vertices - 1] = f_contracted;
            continue;
        }
        // All else failed: shrink every vertex toward the best one.
        let best = simplex[0].clone();
        for i in 1..n_vertices {
            for j in 0..n_dim {
                simplex[i][j] = best[j] + sigma * (simplex[i][j] - best[j]);
            }
            values[i] = evaluate_quadratic_model(&simplex[i], &simplex, &values);
        }
    }
    // Final sort to report the best vertex found.
    let mut indices: Vec<usize> = (0..n_vertices).collect();
    indices.sort_by(|&a, &b| {
        values[a]
            .partial_cmp(&values[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let best_idx = indices[0];
    let best_x = &simplex[best_idx];
    let best_fun = values[best_idx];
    let flat_simplex: Vec<f64> = simplex.iter().flat_map(|v| v.iter().copied()).collect();
    let result = serde_json::json!({
        "x": best_x,
        "fun": best_fun,
        "nit": nit,
        "success": success,
        "simplex": flat_simplex,
        "simplex_values": values,
    });
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
/// Golden-section minimization on `[a, b]` of the built-in demo objective
/// (`evaluate_golden_model`). Returns `{ x, fun, a, b, nit, success }`.
///
/// # Errors
/// Rejects `a >= b` and a non-positive tolerance.
#[wasm_bindgen]
pub fn minimize_golden(a: f64, b: f64, tol: f64, max_iter: u32) -> Result<JsValue, JsValue> {
    if a >= b {
        return Err(WasmError::InvalidParameter(
            "Lower bound 'a' must be less than upper bound 'b'".to_string(),
        )
        .into());
    }
    if tol <= 0.0 {
        return Err(WasmError::InvalidParameter("Tolerance must be positive".to_string()).into());
    }
    // Inverse golden ratio, ~0.618: keeps one probe reusable per shrink.
    let inv_phi = (5.0_f64.sqrt() - 1.0) / 2.0;
    let mut left = a;
    let mut right = b;
    let mut probe_lo = right - inv_phi * (right - left);
    let mut probe_hi = left + inv_phi * (right - left);
    let mut val_lo = evaluate_golden_model(probe_lo, left, right);
    let mut val_hi = evaluate_golden_model(probe_hi, left, right);
    let mut iterations: usize = 0;
    let mut converged = false;
    while iterations < max_iter as usize {
        iterations += 1;
        if (right - left).abs() < tol {
            converged = true;
            break;
        }
        if val_lo < val_hi {
            // Minimum is in [left, probe_hi]; the old lower probe becomes the
            // new upper probe, so only one fresh evaluation is needed.
            right = probe_hi;
            probe_hi = probe_lo;
            val_hi = val_lo;
            probe_lo = right - inv_phi * (right - left);
            val_lo = evaluate_golden_model(probe_lo, left, right);
        } else {
            // Minimum is in [probe_lo, right]; mirror image of the above.
            left = probe_lo;
            probe_lo = probe_hi;
            val_lo = val_hi;
            probe_hi = left + inv_phi * (right - left);
            val_hi = evaluate_golden_model(probe_hi, left, right);
        }
    }
    let x_min = (left + right) / 2.0;
    let f_min = evaluate_golden_model(x_min, left, right);
    let result = serde_json::json!({
        "x": x_min,
        "fun": f_min,
        "a": left,
        "b": right,
        "nit": iterations,
        "success": converged,
    });
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
/// Advance one golden-section iteration: given the current bracket `[a, b]`
/// and the objective values `f_x1`/`f_x2` at the interior probes
/// `x1 = b - phi*(b-a)` and `x2 = a + phi*(b-a)`, return the narrowed
/// bracket, the next probe points, a convergence flag, and the bracket
/// midpoint as `x_min`.
///
/// # Errors
/// Rejects `a >= b` and — consistent with the other solver entry points —
/// a non-positive tolerance (previously `tol <= 0` was accepted and simply
/// made `converged` unreachable).
#[wasm_bindgen]
pub fn golden_section_step(
    a: f64,
    b: f64,
    f_x1: f64,
    f_x2: f64,
    tol: f64,
) -> Result<JsValue, JsValue> {
    if a >= b {
        return Err(WasmError::InvalidParameter(
            "Lower bound 'a' must be less than upper bound 'b'".to_string(),
        )
        .into());
    }
    if tol <= 0.0 {
        return Err(WasmError::InvalidParameter("Tolerance must be positive".to_string()).into());
    }
    let golden_ratio = (5.0_f64.sqrt() - 1.0) / 2.0;
    let x1 = b - golden_ratio * (b - a);
    let x2 = a + golden_ratio * (b - a);
    // Keep whichever sub-interval must contain the minimum.
    let (new_a, new_b) = if f_x1 < f_x2 { (a, x2) } else { (x1, b) };
    let new_x1 = new_b - golden_ratio * (new_b - new_a);
    let new_x2 = new_a + golden_ratio * (new_b - new_a);
    let converged = (new_b - new_a).abs() < tol;
    let result = serde_json::json!({
        "a": new_a,
        "b": new_b,
        "x1": new_x1,
        "x2": new_x2,
        "converged": converged,
        "x_min": (new_a + new_b) / 2.0,
    });
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
/// Brent-style root refinement on `[a, b]` given endpoint values `fa`, `fb`.
///
/// NOTE(review): no objective callback is available from WASM, so after each
/// step the function value at the new iterate is *estimated* by linearly
/// interpolating between the two retained points (`linear_interpolate_f`).
/// The returned root is therefore exact only for that linear model, not for
/// the caller's true objective — confirm this matches the JS-side contract.
///
/// # Errors
/// Rejects same-signed `fa`/`fb` and non-positive tolerances.
#[wasm_bindgen]
pub fn brent_root(
    a: f64,
    b: f64,
    fa: f64,
    fb: f64,
    tol: f64,
    max_iter: u32,
) -> Result<f64, JsValue> {
    if fa * fb > 0.0 {
        return Err(WasmError::InvalidParameter(
            "f(a) and f(b) must have opposite signs for Brent's method".to_string(),
        )
        .into());
    }
    if tol <= 0.0 {
        return Err(WasmError::InvalidParameter("Tolerance must be positive".to_string()).into());
    }
    let mut a = a;
    let mut b = b;
    let mut fa = fa;
    let mut fb = fb;
    // Keep b as the endpoint with the smaller |f|, i.e. the better root guess.
    if fa.abs() < fb.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }
    // c tracks the previous iterate; d and e track the last two step sizes.
    let mut c = a;
    let mut fc = fa;
    let mut d = b - a;
    let mut e = d;
    for _iter in 0..max_iter as usize {
        // Converged on (modeled) function value or on bracket width.
        if fb.abs() < tol {
            return Ok(b);
        }
        if (b - a).abs() < tol {
            return Ok(b);
        }
        // If c is the better guess, rotate the points so b stays the best
        // iterate (same rotation as the classic algorithm).
        if fc.abs() < fb.abs() {
            a = b;
            b = c;
            c = a;
            fa = fb;
            fb = fc;
            fc = fa;
        }
        // Machine-precision-aware tolerance at the current iterate.
        let tol1 = 2.0 * f64::EPSILON * b.abs() + 0.5 * tol;
        let mid = 0.5 * (c - b);
        if mid.abs() <= tol1 || fb.abs() < f64::EPSILON {
            return Ok(b);
        }
        // Try an interpolated step; otherwise fall back to bisection.
        if e.abs() >= tol1 && fa.abs() > fb.abs() {
            let s = if (a - c).abs() < f64::EPSILON {
                // Only two distinct points available: secant step.
                -fb * (b - a) / (fb - fa)
            } else {
                // Inverse quadratic interpolation through (a,fa), (b,fb), (c,fc).
                let r = fb / fc;
                let q = fa / fc;
                let p = fb / fa;
                p * (2.0 * mid * q * (q - r) - (b - a) * (r - 1.0))
                    / ((q - 1.0) * (r - 1.0) * (p - 1.0))
            };
            let s_abs = s.abs();
            // Accept the interpolated step only if it is small enough.
            // NOTE(review): this is a simplified variant of Brent's classic
            // acceptance criterion (which compares p/q before dividing).
            if 2.0 * s_abs < (3.0 * mid * fa.abs()).min(e.abs() * fa.abs()) {
                e = d;
                d = s;
            } else {
                d = mid;
                e = d;
            }
        } else {
            d = mid;
            e = d;
        }
        a = b;
        fa = fb;
        // Take at least a tol1-sized step toward the root.
        if d.abs() > tol1 {
            b += d;
        } else {
            b += if mid > 0.0 { tol1 } else { -tol1 };
        }
        // Estimate f at the new iterate from the linear model through
        // (a, fa) and (c, fc) — no true objective is callable here.
        fb = linear_interpolate_f(b, a, fa, c, fc);
    }
    // Iteration budget exhausted: return the best iterate found.
    Ok(b)
}
/// Bisection root refinement on `[a, b]` given endpoint values `fa`, `fb`.
///
/// The midpoint value is estimated by linear interpolation between the
/// current bracket endpoints (no objective callback is available from WASM),
/// so this refines the root of that linear model.
///
/// # Errors
/// Rejects same-signed `fa`/`fb` and non-positive tolerances.
#[wasm_bindgen]
pub fn bisect_root(
    a: f64,
    b: f64,
    fa: f64,
    fb: f64,
    tol: f64,
    max_iter: u32,
) -> Result<f64, JsValue> {
    if fa * fb > 0.0 {
        return Err(WasmError::InvalidParameter(
            "f(a) and f(b) must have opposite signs for bisection".to_string(),
        )
        .into());
    }
    if tol <= 0.0 {
        return Err(WasmError::InvalidParameter("Tolerance must be positive".to_string()).into());
    }
    // Orient the bracket so x_neg carries the non-positive function value.
    let (mut x_neg, mut x_pos, mut f_neg, mut f_pos) = if fa > 0.0 {
        (b, a, fb, fa)
    } else {
        (a, b, fa, fb)
    };
    for _ in 0..max_iter as usize {
        let mid = (x_neg + x_pos) / 2.0;
        if (x_pos - x_neg).abs() < tol {
            return Ok(mid);
        }
        // Estimate f(mid) from the bracket endpoints (linear model).
        let f_mid = f_neg + (f_pos - f_neg) * (mid - x_neg) / (x_pos - x_neg);
        if f_mid.abs() < tol * 0.01 {
            return Ok(mid);
        }
        // Replace whichever end shares the midpoint's sign.
        if f_mid < 0.0 {
            x_neg = mid;
            f_neg = f_mid;
        } else {
            x_pos = mid;
            f_pos = f_mid;
        }
    }
    // Budget exhausted: report the bracket midpoint.
    Ok((x_neg + x_pos) / 2.0)
}
/// One bisection update: given the bracket `[a, b]`, its endpoint values,
/// and the value at the midpoint, return the halved bracket plus the next
/// midpoint to evaluate.
///
/// # Errors
/// Rejects `fa` and `fb` of the same sign (no bracketed root).
#[wasm_bindgen]
pub fn bisection_step(a: f64, b: f64, fa: f64, fb: f64, f_mid: f64) -> Result<JsValue, JsValue> {
    if fa * fb > 0.0 {
        return Err(WasmError::InvalidParameter(
            "f(a) and f(b) must have opposite signs".to_string(),
        )
        .into());
    }
    let mid = (a + b) / 2.0;
    // Keep whichever half still brackets the sign change.
    let left_half_brackets = fa * f_mid <= 0.0;
    let (new_a, new_b, new_fa, new_fb) = match left_half_brackets {
        true => (a, mid, fa, f_mid),
        false => (mid, b, f_mid, fb),
    };
    let result = serde_json::json!({
        "a": new_a,
        "b": new_b,
        "fa": new_fa,
        "fb": new_fb,
        "mid": (new_a + new_b) / 2.0,
    });
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
/// Ordinary least-squares fit `y = slope * x + intercept`.
///
/// Returns `{ slope, intercept, r_squared, std_err_slope, std_err_intercept, n }`;
/// the standard errors are NaN when fewer than 3 points leave no residual
/// degree of freedom.
///
/// # Errors
/// Rejects mismatched input lengths, fewer than 2 points, and an all-equal
/// `x` vector (degenerate normal equations).
#[wasm_bindgen]
pub fn linear_regression(x: &[f64], y: &[f64]) -> Result<JsValue, JsValue> {
    if x.len() != y.len() {
        return Err(WasmError::ShapeMismatch {
            expected: vec![x.len()],
            actual: vec![y.len()],
        }
        .into());
    }
    let n = x.len();
    if n < 2 {
        return Err(WasmError::InvalidParameter(
            "Need at least 2 data points for linear regression".to_string(),
        )
        .into());
    }
    let n_f = n as f64;
    // Accumulate the classic normal-equation sums.
    let sx: f64 = x.iter().sum();
    let sy: f64 = y.iter().sum();
    let sxx: f64 = x.iter().map(|v| v * v).sum();
    let sxy: f64 = x.iter().zip(y.iter()).map(|(u, v)| u * v).sum();
    let denom = n_f * sxx - sx * sx;
    if denom.abs() < f64::EPSILON {
        return Err(WasmError::ComputationError(
            "All x values are identical; cannot compute regression".to_string(),
        )
        .into());
    }
    let slope = (n_f * sxy - sx * sy) / denom;
    let intercept = (sy - slope * sx) / n_f;
    // Coefficient of determination from total and residual sums of squares.
    let y_mean = sy / n_f;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    let ss_res: f64 = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| (yi - (slope * xi + intercept)).powi(2))
        .sum();
    let r_squared = if ss_tot.abs() < f64::EPSILON {
        // Constant y: the fit is trivially exact.
        1.0
    } else {
        1.0 - ss_res / ss_tot
    };
    // Standard errors require at least one residual degree of freedom.
    let (std_err_slope, std_err_intercept) = if n > 2 {
        let mse = ss_res / (n_f - 2.0);
        let sxx_centered = sxx - sx * sx / n_f;
        (
            (mse / sxx_centered).sqrt(),
            (mse * sxx / (n_f * sxx_centered)).sqrt(),
        )
    } else {
        (f64::NAN, f64::NAN)
    };
    let result = serde_json::json!({
        "slope": slope,
        "intercept": intercept,
        "r_squared": r_squared,
        "std_err_slope": std_err_slope,
        "std_err_intercept": std_err_intercept,
        "n": n,
    });
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
/// Least-squares polynomial fit of the given degree via the normal equations
/// (Vandermonde moments), solved by `solve_linear_system_flat`.
///
/// Returns coefficients in ascending-power order (`c0 + c1*x + ...`).
///
/// # Errors
/// Rejects mismatched input lengths, fewer than `degree + 1` points, and a
/// singular normal-equation system.
#[wasm_bindgen]
pub fn polynomial_fit(x: &[f64], y: &[f64], degree: u32) -> Result<Vec<f64>, JsValue> {
    if x.len() != y.len() {
        return Err(WasmError::ShapeMismatch {
            expected: vec![x.len()],
            actual: vec![y.len()],
        }
        .into());
    }
    let n = x.len();
    let d = degree as usize;
    if n < d + 1 {
        return Err(WasmError::InvalidParameter(format!(
            "Need at least {} data points for degree {} polynomial, got {}",
            d + 1,
            d,
            n
        ))
        .into());
    }
    // Degree 0 reduces to the mean of y.
    if d == 0 {
        return Ok(vec![y.iter().sum::<f64>() / n as f64]);
    }
    let ncols = d + 1;
    // Normal-equation matrix A^T A: entry (r, c) is the moment sum of x^(r+c).
    // The matrix is symmetric, so fill the upper triangle and mirror it.
    let mut ata = vec![0.0; ncols * ncols];
    for row in 0..ncols {
        for col in row..ncols {
            let moment: f64 = x.iter().map(|xi| xi.powi((row + col) as i32)).sum();
            ata[row * ncols + col] = moment;
            ata[col * ncols + row] = moment;
        }
    }
    // Right-hand side A^T y: entry r is the sum of x^r * y.
    let aty: Vec<f64> = (0..ncols)
        .map(|row| {
            x.iter()
                .zip(y.iter())
                .map(|(xi, yi)| xi.powi(row as i32) * yi)
                .sum()
        })
        .collect();
    solve_linear_system_flat(&ata, &aty, ncols)
}
/// Evaluate a polynomial given by ascending-power coefficients at each point
/// in `x`, using Horner's scheme (one multiply-add per coefficient).
#[wasm_bindgen]
pub fn polynomial_eval(coeffs: &[f64], x: &[f64]) -> Vec<f64> {
    x.iter()
        .map(|&xi| coeffs.iter().rev().fold(0.0, |acc, &c| acc * xi + c))
        .collect()
}
/// Component-wise centroid (mean) of a set of points, each expected to have
/// `n_dim` coordinates.
fn compute_centroid(points: &[Vec<f64>], n_dim: usize) -> Vec<f64> {
    let count = points.len() as f64;
    // Sum each coordinate across all points, then divide by the point count.
    let totals = points.iter().fold(vec![0.0; n_dim], |mut acc, point| {
        for (j, component) in point.iter().enumerate() {
            acc[j] += component;
        }
        acc
    });
    totals.into_iter().map(|t| t / count).collect()
}
/// Reflect `point` through `centroid` scaled by `alpha`:
/// `centroid + alpha * (centroid - point)` per coordinate.
fn reflect(centroid: &[f64], point: &[f64], alpha: f64, n_dim: usize) -> Vec<f64> {
    (0..n_dim)
        .map(|i| centroid[i] + alpha * (centroid[i] - point[i]))
        .collect()
}
/// Inverse-distance-weighted surrogate for the objective: a point is scored
/// as the weighted mean of the simplex vertex values, with weight
/// `1 / max(dist, eps)^2` toward each vertex. A point sitting on a vertex is
/// dominated by that vertex's value; an equidistant point gets the average.
fn evaluate_quadratic_model(point: &[f64], simplex: &[Vec<f64>], values: &[f64]) -> f64 {
    let (weight_total, weighted_total) = simplex.iter().zip(values.iter()).fold(
        (0.0_f64, 0.0_f64),
        |(w_acc, v_acc), (vertex, &fval)| {
            let sq_dist: f64 = point
                .iter()
                .zip(vertex.iter())
                .map(|(p, v)| (p - v).powi(2))
                .sum();
            // Clamp the distance away from zero to keep the weight finite.
            let dist = sq_dist.sqrt().max(f64::EPSILON);
            let weight = 1.0 / (dist * dist);
            (w_acc + weight, v_acc + weight * fval)
        },
    );
    if weight_total > 0.0 {
        weighted_total / weight_total
    } else {
        // Empty simplex would leave zero weight; fall back to the first value.
        values[0]
    }
}
/// Placeholder objective for the golden-section demo: f(x) = x^2, with a
/// minimum at 0. The bracket arguments are accepted for signature symmetry
/// with the callers but are unused.
fn evaluate_golden_model(x: f64, _lo: f64, _hi: f64) -> f64 {
    x.powi(2)
}
/// Linearly interpolate (or extrapolate) the value at `x` from the two
/// samples `(x1, f1)` and `(x2, f2)`. Degenerate spans (`x1 ≈ x2`) return
/// `f1` to avoid dividing by zero.
fn linear_interpolate_f(x: f64, x1: f64, f1: f64, x2: f64, f2: f64) -> f64 {
    let span = x2 - x1;
    if span.abs() < f64::EPSILON {
        return f1;
    }
    let slope_term = (f2 - f1) * (x - x1) / span;
    f1 + slope_term
}
/// Solve the dense `n x n` system `a * x = b` (row-major flat `a`) by
/// Gaussian elimination with partial pivoting, returning the solution vector.
///
/// # Errors
/// Returns a `ComputationError` when a pivot (or a back-substitution
/// diagonal) magnitude falls below 1e-15, i.e. the system is singular or
/// near-singular.
fn solve_linear_system_flat(a: &[f64], b: &[f64], n: usize) -> Result<Vec<f64>, JsValue> {
    // Build the augmented matrix [a | b]: n rows, n + 1 columns, row-major.
    let mut aug = vec![0.0; n * (n + 1)];
    for i in 0..n {
        for j in 0..n {
            aug[i * (n + 1) + j] = a[i * n + j];
        }
        aug[i * (n + 1) + n] = b[i];
    }
    // Forward elimination with partial (row) pivoting.
    for col in 0..n {
        // Find the row with the largest magnitude in this column at or below
        // the diagonal; pivoting on it improves numerical stability.
        let mut max_val = aug[col * (n + 1) + col].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let val = aug[row * (n + 1) + col].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }
        if max_val < 1e-15 {
            return Err(WasmError::ComputationError(
                "Singular or near-singular system in polynomial fit".to_string(),
            )
            .into());
        }
        // Swap the pivot row into the diagonal position.
        if max_row != col {
            for j in 0..=n {
                let idx_col = col * (n + 1) + j;
                let idx_max = max_row * (n + 1) + j;
                aug.swap(idx_col, idx_max);
            }
        }
        // Eliminate every entry below the pivot in this column.
        let pivot = aug[col * (n + 1) + col];
        for row in (col + 1)..n {
            let factor = aug[row * (n + 1) + col] / pivot;
            for j in col..=n {
                let above = aug[col * (n + 1) + j];
                aug[row * (n + 1) + j] -= factor * above;
            }
        }
    }
    // Back substitution from the last row upward.
    let mut result = vec![0.0; n];
    for i in (0..n).rev() {
        let mut sum = aug[i * (n + 1) + n];
        for j in (i + 1)..n {
            sum -= aug[i * (n + 1) + j] * result[j];
        }
        let diag = aug[i * (n + 1) + i];
        // Defensive guard: elimination should already have rejected this.
        if diag.abs() < 1e-15 {
            return Err(WasmError::ComputationError(
                "Zero diagonal in back substitution".to_string(),
            )
            .into());
        }
        result[i] = sum / diag;
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Gated to wasm32: the Ok path serializes through serde_wasm_bindgen,
    // which presumably requires a JS host to run — confirm.
    #[cfg(target_arch = "wasm32")]
    #[test]
    fn test_linear_regression_basic() {
        // Perfect line y = 2x should fit without error.
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let result = linear_regression(&x, &y);
        assert!(result.is_ok());
    }
    #[test]
    fn test_linear_regression_error_mismatched_lengths() {
        // x and y of different lengths must be rejected.
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0, 3.0];
        let result = linear_regression(&x, &y);
        assert!(result.is_err());
    }
    #[test]
    fn test_linear_regression_error_too_few_points() {
        // A single point cannot determine a line.
        let x = vec![1.0];
        let y = vec![2.0];
        let result = linear_regression(&x, &y);
        assert!(result.is_err());
    }
    #[test]
    fn test_polynomial_fit_linear() {
        // y = 1 + 2x; coefficients come back in ascending-power order.
        let x = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let y = vec![1.0, 3.0, 5.0, 7.0, 9.0];
        let coeffs = polynomial_fit(&x, &y, 1);
        assert!(coeffs.is_ok());
        let coeffs = coeffs.expect("polynomial_fit should succeed");
        assert!(
            (coeffs[0] - 1.0).abs() < 1e-10,
            "intercept should be ~1.0, got {}",
            coeffs[0]
        );
        assert!(
            (coeffs[1] - 2.0).abs() < 1e-10,
            "slope should be ~2.0, got {}",
            coeffs[1]
        );
    }
    #[test]
    fn test_polynomial_fit_quadratic() {
        // y = x^2 sampled symmetrically around zero.
        let x = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let y = vec![4.0, 1.0, 0.0, 1.0, 4.0];
        let coeffs = polynomial_fit(&x, &y, 2);
        assert!(coeffs.is_ok());
        let coeffs = coeffs.expect("polynomial_fit should succeed");
        assert!(
            (coeffs[0]).abs() < 1e-10,
            "c0 should be ~0, got {}",
            coeffs[0]
        );
        assert!(
            (coeffs[1]).abs() < 1e-10,
            "c1 should be ~0, got {}",
            coeffs[1]
        );
        assert!(
            (coeffs[2] - 1.0).abs() < 1e-10,
            "c2 should be ~1, got {}",
            coeffs[2]
        );
    }
    #[test]
    fn test_polynomial_eval_horner() {
        // p(x) = 1 + 2x + 3x^2 evaluated at x = 0, 1, 2.
        let coeffs = vec![1.0, 2.0, 3.0];
        let x = vec![0.0, 1.0, 2.0];
        let vals = polynomial_eval(&coeffs, &x);
        assert!((vals[0] - 1.0).abs() < 1e-10);
        assert!((vals[1] - 6.0).abs() < 1e-10);
        assert!((vals[2] - 17.0).abs() < 1e-10);
    }
    #[test]
    fn test_polynomial_fit_error_too_few_points() {
        // A degree-2 fit needs at least 3 points.
        let x = vec![1.0];
        let y = vec![2.0];
        let result = polynomial_fit(&x, &y, 2);
        assert!(result.is_err());
    }
    #[test]
    fn test_bisect_root_linear() {
        // Linear model f(x) = x - 1 bracketed on [0, 2]; root at 1.
        let result = bisect_root(0.0, 2.0, -1.0, 1.0, 1e-10, 100);
        assert!(result.is_ok());
        let root = result.expect("bisect_root should succeed");
        assert!(
            (root - 1.0).abs() < 1e-6,
            "root should be ~1.0, got {}",
            root
        );
    }
    #[test]
    fn test_bisect_root_error_same_sign() {
        // Endpoints with the same sign do not bracket a root.
        let result = bisect_root(0.0, 2.0, 1.0, 3.0, 1e-10, 100);
        assert!(result.is_err());
    }
    #[test]
    fn test_brent_root_linear() {
        // Same linear bracket as the bisection test; Brent should agree.
        let result = brent_root(0.0, 2.0, -1.0, 1.0, 1e-10, 100);
        assert!(result.is_ok());
        let root = result.expect("brent_root should succeed");
        assert!(
            (root - 1.0).abs() < 1e-6,
            "root should be ~1.0, got {}",
            root
        );
    }
    #[test]
    fn test_brent_root_error_same_sign() {
        // Same-sign endpoints must be rejected before iterating.
        let result = brent_root(0.0, 2.0, 1.0, 3.0, 1e-10, 100);
        assert!(result.is_err());
    }
    #[test]
    fn test_solve_linear_system_flat() {
        // 2x2 system [2 1; 1 3] * v = [5; 7] has solution v = (1.6, 1.8).
        let a = vec![2.0, 1.0, 1.0, 3.0];
        let b = vec![5.0, 7.0];
        let result = solve_linear_system_flat(&a, &b, 2);
        assert!(result.is_ok());
        let sol = result.expect("solve should succeed");
        assert!((sol[0] - 1.6).abs() < 1e-10);
        assert!((sol[1] - 1.8).abs() < 1e-10);
    }
    #[test]
    fn test_nelder_mead_validates_inputs() {
        // Too few vertices.
        let result = minimize_nelder_mead(&[1.0], &[0.0], 1e-6, 100);
        assert!(result.is_err());
        // Coordinate count does not match n_vertices * n_dim.
        let result = minimize_nelder_mead(&[1.0, 2.0, 3.0], &[0.0, 0.0], 1e-6, 100);
        assert!(result.is_err());
        // Non-positive tolerance.
        let result =
            minimize_nelder_mead(&[1.0, 2.0, 3.0], &[0.0, 1.0, 0.5, 0.0, 0.0, 1.0], -1.0, 100);
        assert!(result.is_err());
    }
}