use crate::error::WasmError;
use wasm_bindgen::prelude::*;
/// Converts a JavaScript numeric sequence into an owned `Vec<f64>`.
///
/// Values for which `JsValue::is_array` is true are wrapped as a
/// `js_sys::Array` and converted by `crate::utils::js_array_to_vec_f64`.
/// Every other value is passed straight to
/// `crate::utils::typed_array_to_vec_f64`. Presumably that path expects a
/// typed array such as `Float64Array`; confirm against that helper.
///
/// # Errors
/// Returns whatever `WasmError` the selected `crate::utils` helper produces.
fn js_to_vec(val: &JsValue) -> Result<Vec<f64>, WasmError> {
    if !val.is_array() {
        return crate::utils::typed_array_to_vec_f64(val);
    }
    crate::utils::js_array_to_vec_f64(&js_sys::Array::from(val))
}
/// Integrates sampled data `y(x)` with the composite trapezoidal rule.
///
/// Both inputs may be plain JS arrays or typed arrays (converted via
/// `js_to_vec`). Spacing need not be uniform. Every interval contributes
/// `0.5 * (x[i + 1] - x[i]) * (y[i] + y[i + 1])`.
///
/// # Errors
/// - `WasmError::ShapeMismatch` when `y` and `x` differ in length.
/// - `WasmError::InvalidParameter` when fewer than two samples are supplied.
/// - Any conversion error raised while reading the JS inputs.
#[wasm_bindgen]
pub fn trapezoid(y_js: &JsValue, x_js: &JsValue) -> Result<f64, JsValue> {
    let y = js_to_vec(y_js)?;
    let x = js_to_vec(x_js)?;

    let (len_y, len_x) = (y.len(), x.len());
    if len_y != len_x {
        let mismatch = WasmError::ShapeMismatch {
            expected: vec![len_x],
            actual: vec![len_y],
        };
        return Err(mismatch.into());
    }
    if len_y < 2 {
        let message = "At least 2 points are required for trapezoidal integration".to_string();
        return Err(WasmError::InvalidParameter(message).into());
    }

    // Walk adjacent (x, y) pairs. The accumulator starts at +0.0 on purpose
    // so results match a plain `+=` loop bit for bit.
    let mut area = 0.0;
    for (xs, ys) in x.windows(2).zip(y.windows(2)) {
        let dx = xs[1] - xs[0];
        area += 0.5 * dx * (ys[0] + ys[1]);
    }
    Ok(area)
}
/// Integrates sampled data `y(x)` with the composite Simpson's 1/3 rule,
/// allowing non-uniform spacing.
///
/// - Consecutive pairs of intervals are integrated with the non-uniform
///   Simpson formula, which is exact for quadratics.
/// - If the sample count is even, one interval is left over at the end. It is
///   integrated with the parabola through the last three samples (Cartwright's
///   correction). Quadratics therefore stay exact for any sample count,
///   instead of degrading to a trapezoid on the final interval.
/// - With exactly two samples the trapezoidal rule is used.
///
/// NOTE: coincident `x` values (a zero-width interval) divide by zero and
/// produce a non-finite result.
///
/// # Errors
/// - `WasmError::ShapeMismatch` when `y` and `x` differ in length.
/// - `WasmError::InvalidParameter` when fewer than two samples are supplied.
/// - Any conversion error raised while reading the JS inputs.
#[wasm_bindgen]
pub fn simpson(y_js: &JsValue, x_js: &JsValue) -> Result<f64, JsValue> {
    let y = js_to_vec(y_js)?;
    let x = js_to_vec(x_js)?;
    if y.len() != x.len() {
        return Err(WasmError::ShapeMismatch {
            expected: vec![x.len()],
            actual: vec![y.len()],
        }
        .into());
    }
    let n = y.len();
    if n < 2 {
        return Err(WasmError::InvalidParameter(
            "At least 2 points are required for Simpson's rule integration".to_string(),
        )
        .into());
    }
    if n == 2 {
        // A single interval has no midpoint, so fall back to the trapezoid.
        let dx = x[1] - x[0];
        return Ok(0.5 * dx * (y[0] + y[1]));
    }

    let mut integral = 0.0;
    let mut i = 0;
    // Non-uniform Simpson over [x[i], x[i + 2]] with widths h0, h1.
    while i + 2 < n {
        let h0 = x[i + 1] - x[i];
        let h1 = x[i + 2] - x[i + 1];
        let h_sum = h0 + h1;
        let seg = (h_sum / 6.0)
            * (y[i] * (2.0 - h1 / h0)
                + y[i + 1] * (h_sum * h_sum / (h0 * h1))
                + y[i + 2] * (2.0 - h0 / h1));
        integral += seg;
        i += 2;
    }

    // An even sample count leaves exactly one interval [x[i], x[i + 1]].
    // Here i == n - 2 >= 2, so x[i - 1] exists. Integrate the parabola through
    // samples i - 1, i and i + 1 over that last interval only.
    if i + 1 < n {
        let h0 = x[i] - x[i - 1];
        let h1 = x[i + 1] - x[i];
        let alpha = (2.0 * h1 * h1 + 3.0 * h0 * h1) / (6.0 * (h0 + h1));
        let beta = (h1 * h1 + 3.0 * h0 * h1) / (6.0 * h0);
        let eta = h1 * h1 * h1 / (6.0 * h0 * (h0 + h1));
        integral += alpha * y[i + 1] + beta * y[i] - eta * y[i - 1];
    }
    Ok(integral)
}
/// Advances a state vector by one step of size `h` using precomputed
/// derivative values.
///
/// NOTE(review): despite the name, this performs a single explicit (forward)
/// Euler update, `y_next[i] = y[i] + h * f_vals[i]`. No intermediate RK4
/// stages are evaluated, and the time argument `_t` is ignored.
///
/// # Errors
/// - `WasmError::ShapeMismatch` when `f_vals` and `y` differ in length.
/// - Any conversion error raised while reading the JS inputs.
#[wasm_bindgen]
pub fn rk4_step(
    f_vals_js: &JsValue,
    _t: f64,
    y_js: &JsValue,
    h: f64,
) -> Result<js_sys::Float64Array, JsValue> {
    let f_vals = js_to_vec(f_vals_js)?;
    let mut state = js_to_vec(y_js)?;

    if f_vals.len() != state.len() {
        let mismatch = WasmError::ShapeMismatch {
            expected: vec![state.len()],
            actual: vec![f_vals.len()],
        };
        return Err(mismatch.into());
    }

    // Update the owned state in place rather than building a fresh vector.
    for (component, &slope) in state.iter_mut().zip(f_vals.iter()) {
        *component += h * slope;
    }
    Ok(crate::utils::vec_f64_to_typed_array(state))
}
#[wasm_bindgen]
pub fn ode_solve(y0_js: &JsValue, t_span_js: &JsValue, n_steps: u32) -> Result<JsValue, JsValue> {
let y0 = js_to_vec(y0_js)?;
let t_span = js_to_vec(t_span_js)?;
if t_span.len() != 2 {
return Err(WasmError::InvalidParameter(
"t_span must be a two-element array [t_start, t_end]".to_string(),
)
.into());
}
if y0.is_empty() {
return Err(
WasmError::InvalidParameter("y0 must have at least one element".to_string()).into(),
);
}
if n_steps == 0 {
return Err(
WasmError::InvalidParameter("n_steps must be greater than 0".to_string()).into(),
);
}
let t_start = t_span[0];
let t_end = t_span[1];
let h = (t_end - t_start) / n_steps as f64;
let dim = y0.len();
let total_points = n_steps as usize + 1;
let mut t_values: Vec<f64> = Vec::with_capacity(total_points);
let mut y_values: Vec<Vec<f64>> = Vec::with_capacity(total_points);
let mut t = t_start;
let mut y = y0;
t_values.push(t);
y_values.push(y.clone());
for _ in 0..n_steps {
let k1: Vec<f64> = y.iter().map(|&yi| -yi).collect();
let y_mid1: Vec<f64> = (0..dim).map(|j| y[j] + 0.5 * h * k1[j]).collect();
let k2: Vec<f64> = y_mid1.iter().map(|&yi| -yi).collect();
let y_mid2: Vec<f64> = (0..dim).map(|j| y[j] + 0.5 * h * k2[j]).collect();
let k3: Vec<f64> = y_mid2.iter().map(|&yi| -yi).collect();
let y_end: Vec<f64> = (0..dim).map(|j| y[j] + h * k3[j]).collect();
let k4: Vec<f64> = y_end.iter().map(|&yi| -yi).collect();
y = (0..dim)
.map(|j| y[j] + (h / 6.0) * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]))
.collect();
t += h;
t_values.push(t);
y_values.push(y.clone());
}
let result = serde_json::json!({
"t": t_values,
"y": y_values,
});
serde_wasm_bindgen::to_value(&result)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
/// Returns the running trapezoidal integral of sampled data `y(x)`.
///
/// Entry `i` of the output holds the integral from `x[0]` to `x[i + 1]`, so
/// the result has `len - 1` elements and no leading zero. Spacing need not be
/// uniform.
///
/// # Errors
/// - `WasmError::ShapeMismatch` when `y` and `x` differ in length.
/// - `WasmError::InvalidParameter` when fewer than two samples are supplied.
/// - Any conversion error raised while reading the JS inputs.
#[wasm_bindgen]
pub fn cumulative_trapezoid(
    y_js: &JsValue,
    x_js: &JsValue,
) -> Result<js_sys::Float64Array, JsValue> {
    let y = js_to_vec(y_js)?;
    let x = js_to_vec(x_js)?;

    if y.len() != x.len() {
        let mismatch = WasmError::ShapeMismatch {
            expected: vec![x.len()],
            actual: vec![y.len()],
        };
        return Err(mismatch.into());
    }
    if y.len() < 2 {
        let message =
            "At least 2 points are required for cumulative trapezoidal integration".to_string();
        return Err(WasmError::InvalidParameter(message).into());
    }

    // `scan` carries the running total and emits it after every interval.
    let running: Vec<f64> = x
        .windows(2)
        .zip(y.windows(2))
        .scan(0.0_f64, |total, (xs, ys)| {
            let dx = xs[1] - xs[0];
            *total += 0.5 * dx * (ys[0] + ys[1]);
            Some(*total)
        })
        .collect();
    Ok(crate::utils::vec_f64_to_typed_array(running))
}
/// Integrates uniformly spaced samples with Romberg's method.
///
/// `y` must contain `2^k + 1` samples spaced `dx` apart. Trapezoid estimates
/// with successively halved step sizes are refined by Richardson
/// extrapolation. The most extrapolated entry of the final tableau row is
/// returned. With two samples this is just the trapezoid rule.
///
/// # Errors
/// - `WasmError::InvalidParameter` for any of:
///   - fewer than 2 samples;
///   - `dx` not positive;
///   - `dx` NaN or infinite;
///   - an interval count that is not a power of two.
/// - `WasmError::ComputationError` if the final tableau row is empty (not
///   reachable for validated input).
/// - Any conversion error raised while reading `y_js`.
#[wasm_bindgen]
pub fn romberg(y_js: &JsValue, dx: f64) -> Result<f64, JsValue> {
    let y = js_to_vec(y_js)?;
    if y.len() < 2 {
        return Err(WasmError::InvalidParameter(
            "At least 2 points are required for Romberg integration".to_string(),
        )
        .into());
    }
    if dx <= 0.0 {
        return Err(WasmError::InvalidParameter("dx must be positive".to_string()).into());
    }
    // `dx <= 0.0` is false for NaN, so NaN (and +inf) would otherwise pass
    // validation and silently poison the result.
    if !dx.is_finite() {
        return Err(WasmError::InvalidParameter("dx must be a finite number".to_string()).into());
    }

    let n = y.len();
    let intervals = n - 1;
    if !intervals.is_power_of_two() {
        return Err(WasmError::InvalidParameter(format!(
            "Number of points must be 2^k + 1, got {} points ({} intervals)",
            n, intervals
        ))
        .into());
    }

    // Exact for powers of two, unlike rounding a floating-point log2.
    let levels = intervals.trailing_zeros() as usize + 1;

    // Only the previous tableau row is needed to build the next one.
    let h0 = intervals as f64 * dx;
    let mut prev_row: Vec<f64> = vec![0.5 * h0 * (y[0] + y[intervals])];
    for j in 1..levels {
        let step = intervals >> j;
        let h = step as f64 * dx;

        // New sample points at this level are the odd multiples of `step`
        // (step, 3*step, ..., intervals - step), which the previous level skipped.
        let mut mid_sum = 0.0;
        for &sample in y[step..intervals].iter().step_by(2 * step) {
            mid_sum += sample;
        }

        let mut row = Vec::with_capacity(j + 1);
        row.push(0.5 * prev_row[0] + h * mid_sum);
        // Richardson extrapolation. Row j-1 has exactly j entries, so
        // prev_row[m - 1] is always in bounds for m in 1..=j.
        for m in 1..=j {
            let factor = 4.0_f64.powi(m as i32);
            let refined = (factor * row[m - 1] - prev_row[m - 1]) / (factor - 1.0);
            row.push(refined);
        }
        prev_row = row;
    }

    let best = prev_row
        .last()
        .copied()
        .ok_or_else(|| WasmError::ComputationError("Romberg table row is empty".to_string()))?;
    Ok(best)
}
// Unit tests for the numeric kernels in this module.
//
// NOTE(review): these tests re-implement the numeric loops inline instead of
// calling the exported functions, which take `&JsValue`. Presumably this is
// because `JsValue` cannot be used in a native `cargo test` run; confirm.
// Any change to the production algorithms has to be mirrored here by hand,
// or these tests stop covering the real code.
#[cfg(test)]
mod tests {
// Sanity-checks the power-of-two bit test `n & (n - 1) == 0` for interval counts.
#[test]
fn test_romberg_power_of_two_check() {
let four_intervals: usize = 4;
assert_eq!(four_intervals & (four_intervals - 1), 0);
let three_intervals: usize = 3;
assert_ne!(three_intervals & (three_intervals - 1), 0);
}
// Inline trapezoid loop on y = x over [0, 1]; the rule is exact for linear data,
// so the result must be 0.5.
#[test]
fn test_trapezoid_logic() {
let x = [0.0, 0.25, 0.5, 0.75, 1.0];
let y = [0.0, 0.25, 0.5, 0.75, 1.0];
let mut integral = 0.0;
for i in 0..y.len() - 1 {
let dx = x[i + 1] - x[i];
integral += 0.5 * dx * (y[i] + y[i + 1]);
}
assert!((integral - 0.5_f64).abs() < 1e-12_f64);
}
// Inline non-uniform Simpson loop on y = x^2 over [0, 1] with 5 points (an odd
// count, so no leftover interval). Simpson is exact for quadratics: expect 1/3.
#[test]
fn test_simpson_logic() {
let n = 5;
let x: Vec<f64> = (0..n).map(|i| i as f64 / (n - 1) as f64).collect();
let y: Vec<f64> = x.iter().map(|&xi| xi * xi).collect();
let intervals = n - 1;
let mut integral = 0.0;
let mut i = 0;
while i + 2 < n {
let h0 = x[i + 1] - x[i];
let h1 = x[i + 2] - x[i + 1];
let h_sum = h0 + h1;
let seg = (h_sum / 6.0)
* (y[i] * (2.0 - h1 / h0)
+ y[i + 1] * (h_sum * h_sum / (h0 * h1))
+ y[i + 2] * (2.0 - h0 / h1));
integral += seg;
i += 2;
}
// Not reached for n = 5; kept to mirror the trailing-interval branch.
if i < intervals {
let dx = x[i + 1] - x[i];
integral += 0.5 * dx * (y[i] + y[i + 1]);
}
assert!(
(integral - 1.0 / 3.0).abs() < 1e-10,
"Simpson's rule for x^2: got {}, expected {}",
integral,
1.0 / 3.0
);
}
// Inline running trapezoid on constant y = 1 with unit spacing: entry i must be
// i + 1, and the output has one fewer element than the input.
#[test]
fn test_cumulative_trapezoid_logic() {
let x = [0.0, 1.0, 2.0, 3.0, 4.0];
let y = [1.0, 1.0, 1.0, 1.0, 1.0];
let n = y.len();
let mut result = Vec::with_capacity(n - 1);
let mut cumulative = 0.0;
for i in 0..n - 1 {
let dx = x[i + 1] - x[i];
cumulative += 0.5 * dx * (y[i] + y[i + 1]);
result.push(cumulative);
}
assert_eq!(result.len(), 4);
for (i, &val) in result.iter().enumerate() {
assert!(
(val - (i + 1) as f64).abs() < 1e-12,
"cumulative[{}] = {}, expected {}",
i,
val,
i + 1
);
}
}
// Inline Romberg tableau on y = x sampled at 5 points (4 intervals, dx = 0.25);
// the integral over [0, 1] must be 0.5.
#[test]
fn test_romberg_logic() {
let y = [0.0, 0.25, 0.5, 0.75, 1.0];
let dx = 0.25;
let n = y.len();
let intervals = n - 1;
let k = (intervals as f64).log2().round() as usize; let levels = k + 1;
let mut r: Vec<Vec<f64>> = Vec::with_capacity(levels);
let h0 = intervals as f64 * dx;
let t0 = 0.5 * h0 * (y[0] + y[intervals]);
r.push(vec![t0]);
for j in 1..levels {
let step = intervals >> j;
let h = step as f64 * dx;
let n_new = 1usize << (j - 1);
let prev_t = r[j - 1][0];
let mut mid_sum = 0.0;
for m in 0..n_new {
let idx = step * (2 * m + 1);
if idx < y.len() {
mid_sum += y[idx];
}
}
let t_j = 0.5 * prev_t + h * mid_sum;
let mut row = Vec::with_capacity(j + 1);
row.push(t_j);
for m in 1..=j {
let factor = 4.0_f64.powi(m as i32);
let prev_col = r[j - 1][m - 1];
let val = (factor * row[m - 1] - prev_col) / (factor - 1.0);
row.push(val);
}
r.push(row);
}
let best = *r
.last()
.and_then(|row| row.last())
.expect("non-empty table");
assert!(
(best - 0.5).abs() < 1e-12,
"Romberg for y=x: got {}, expected 0.5",
best
);
}
// Inline RK4 for y' = -y, y(0) = 1, over 1000 steps to t = 1; the result must
// match exp(-1) to within 1e-10.
#[test]
fn test_ode_rk4_exponential_decay() {
let y0 = vec![1.0];
let t_start = 0.0;
let t_end = 1.0;
let n_steps = 1000u32;
let h = (t_end - t_start) / n_steps as f64;
let dim = y0.len();
let mut y = y0;
for _ in 0..n_steps {
let k1: Vec<f64> = y.iter().map(|&yi| -yi).collect();
let y_mid1: Vec<f64> = (0..dim).map(|j| y[j] + 0.5 * h * k1[j]).collect();
let k2: Vec<f64> = y_mid1.iter().map(|&yi| -yi).collect();
let y_mid2: Vec<f64> = (0..dim).map(|j| y[j] + 0.5 * h * k2[j]).collect();
let k3: Vec<f64> = y_mid2.iter().map(|&yi| -yi).collect();
let y_end: Vec<f64> = (0..dim).map(|j| y[j] + h * k3[j]).collect();
let k4: Vec<f64> = y_end.iter().map(|&yi| -yi).collect();
y = (0..dim)
.map(|j| y[j] + (h / 6.0) * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]))
.collect();
}
let expected = (-1.0_f64).exp();
assert!(
(y[0] - expected).abs() < 1e-10,
"RK4 exponential decay: got {}, expected {}",
y[0],
expected
);
}
}