use crate::math_wasm::{Matrix, Vector};
use crate::optimized_solver::OptimizedConjugateGradientSolver;
use crate::solver_core::{ConjugateGradientSolver, SolverConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use wasm_bindgen::prelude::*;
#[derive(Serialize, Deserialize, Clone)]
pub struct WasmSolverConfig {
pub max_iterations: usize,
pub tolerance: f64,
pub simd_enabled: bool,
pub stream_chunk_size: usize,
}
impl Default for WasmSolverConfig {
fn default() -> Self {
Self {
max_iterations: 1000,
tolerance: 1e-10,
simd_enabled: cfg!(target_feature = "simd128"),
stream_chunk_size: 100,
}
}
}
#[derive(Serialize, Deserialize, Clone)]
pub struct SolutionStep {
pub iteration: usize,
pub residual: f64,
pub timestamp: f64,
pub convergence: bool,
}
#[derive(Serialize, Deserialize)]
pub struct MemoryUsage {
pub used: usize,
pub capacity: usize,
}
#[wasm_bindgen]
pub struct WasmSublinearSolver {
config: WasmSolverConfig,
solver: OptimizedConjugateGradientSolver,
callbacks: HashMap<String, js_sys::Function>,
memory_usage: usize,
}
#[wasm_bindgen]
impl WasmSublinearSolver {
#[wasm_bindgen(constructor)]
pub fn new(config: JsValue) -> Result<WasmSublinearSolver, JsValue> {
crate::set_panic_hook();
let config: WasmSolverConfig = if config.is_undefined() {
WasmSolverConfig::default()
} else {
serde_wasm_bindgen::from_value(config)
.map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?
};
let solver_config = SolverConfig {
max_iterations: config.max_iterations,
tolerance: config.tolerance,
};
#[cfg(feature = "std")]
let solver = if config.simd_enabled {
OptimizedConjugateGradientSolver::new_parallel(solver_config)
} else {
OptimizedConjugateGradientSolver::new(solver_config)
};
#[cfg(not(feature = "std"))]
let solver = OptimizedConjugateGradientSolver::new(solver_config);
Ok(WasmSublinearSolver {
config,
solver,
callbacks: HashMap::new(),
memory_usage: 0,
})
}
#[wasm_bindgen]
pub fn solve(
&mut self,
matrix_data: &[f64],
matrix_rows: usize,
matrix_cols: usize,
vector_data: &[f64],
) -> Result<Vec<f64>, JsValue> {
if matrix_data.len() != matrix_rows * matrix_cols {
return Err(JsValue::from_str("Matrix dimensions mismatch"));
}
if vector_data.len() != matrix_rows {
return Err(JsValue::from_str("Vector size mismatch"));
}
let matrix = Matrix::from_slice(matrix_data, matrix_rows, matrix_cols);
let vector = Vector::from_slice(vector_data);
self.memory_usage = matrix_data.len() * 8 + vector_data.len() * 8;
match self.solver.solve(&matrix, &vector) {
Ok(solution) => Ok(solution.data().to_vec()),
Err(e) => Err(JsValue::from_str(&format!("Solver error: {}", e))),
}
}
#[wasm_bindgen]
pub fn solve_stream(
&mut self,
matrix_data: &[f64],
matrix_rows: usize,
matrix_cols: usize,
vector_data: &[f64],
progress_callback: &js_sys::Function,
) -> Result<Vec<f64>, JsValue> {
if matrix_data.len() != matrix_rows * matrix_cols {
return Err(JsValue::from_str("Matrix dimensions mismatch"));
}
if vector_data.len() != matrix_rows {
return Err(JsValue::from_str("Vector size mismatch"));
}
let matrix = Matrix::from_slice(matrix_data, matrix_rows, matrix_cols);
let vector = Vector::from_slice(vector_data);
self.memory_usage = matrix_data.len() * 8 + vector_data.len() * 8;
let mut solution = None;
let chunk_size = self.config.stream_chunk_size;
let result = self
.solver
.solve_with_callback(&matrix, &vector, chunk_size, |step_data| {
let timestamp = js_sys::Date::now();
let step = SolutionStep {
iteration: step_data.iteration,
residual: step_data.residual,
timestamp,
convergence: step_data.converged,
};
let step_js = serde_wasm_bindgen::to_value(&step).unwrap();
let _ = progress_callback.call1(&JsValue::NULL, &step_js);
if step_data.converged {
solution = Some(step_data.solution.clone());
}
});
match result {
Ok(final_solution) => Ok(final_solution.data().to_vec()),
Err(e) => Err(JsValue::from_str(&format!("Streaming solver error: {}", e))),
}
}
#[wasm_bindgen]
pub fn solve_batch(&mut self, batch_data: JsValue) -> Result<JsValue, JsValue> {
#[derive(Deserialize)]
struct BatchRequest {
id: String,
matrix_data: Vec<f64>,
matrix_rows: usize,
matrix_cols: usize,
vector_data: Vec<f64>,
}
#[derive(Serialize)]
struct BatchResult {
id: String,
solution: Vec<f64>,
iterations: usize,
error: Option<String>,
}
let requests: Vec<BatchRequest> = serde_wasm_bindgen::from_value(batch_data)
.map_err(|e| JsValue::from_str(&format!("Invalid batch data: {}", e)))?;
let mut results = Vec::new();
for request in requests {
let result = match self.solve(
&request.matrix_data,
request.matrix_rows,
request.matrix_cols,
&request.vector_data,
) {
Ok(solution) => BatchResult {
id: request.id,
solution,
iterations: self.solver.get_last_iteration_count(),
error: None,
},
Err(e) => BatchResult {
id: request.id,
solution: Vec::new(),
iterations: 0,
error: Some(format!("{:?}", e)),
},
};
results.push(result);
}
serde_wasm_bindgen::to_value(&results)
.map_err(|e| JsValue::from_str(&format!("Failed to serialize results: {}", e)))
}
#[wasm_bindgen(getter)]
pub fn memory_usage(&self) -> JsValue {
let usage = MemoryUsage {
used: self.memory_usage,
capacity: self.memory_usage * 2, };
serde_wasm_bindgen::to_value(&usage).unwrap()
}
#[wasm_bindgen]
pub fn get_config(&self) -> JsValue {
serde_wasm_bindgen::to_value(&self.config).unwrap()
}
#[wasm_bindgen]
pub fn dispose(&mut self) {
self.callbacks.clear();
self.memory_usage = 0;
}
}
#[wasm_bindgen]
pub struct MatrixView {
data: Vec<f64>,
rows: usize,
cols: usize,
}
#[wasm_bindgen]
impl MatrixView {
#[wasm_bindgen(constructor)]
pub fn new(rows: usize, cols: usize) -> MatrixView {
MatrixView {
data: vec![0.0; rows * cols],
rows,
cols,
}
}
#[wasm_bindgen(getter)]
pub fn data(&self) -> *const f64 {
self.data.as_ptr()
}
#[wasm_bindgen(getter)]
pub fn length(&self) -> usize {
self.data.len()
}
#[wasm_bindgen(getter)]
pub fn rows(&self) -> usize {
self.rows
}
#[wasm_bindgen(getter)]
pub fn cols(&self) -> usize {
self.cols
}
#[wasm_bindgen]
pub fn data_view(&self) -> js_sys::Float64Array {
unsafe { js_sys::Float64Array::view(&self.data) }
}
#[wasm_bindgen]
pub fn set_data(&mut self, data: &[f64]) -> Result<(), JsValue> {
if data.len() != self.data.len() {
return Err(JsValue::from_str("Data length mismatch"));
}
self.data.copy_from_slice(data);
Ok(())
}
#[wasm_bindgen]
pub fn get_element(&self, row: usize, col: usize) -> Result<f64, JsValue> {
if row >= self.rows || col >= self.cols {
return Err(JsValue::from_str("Index out of bounds"));
}
Ok(self.data[row * self.cols + col])
}
#[wasm_bindgen]
pub fn set_element(&mut self, row: usize, col: usize, value: f64) -> Result<(), JsValue> {
if row >= self.rows || col >= self.cols {
return Err(JsValue::from_str("Index out of bounds"));
}
self.data[row * self.cols + col] = value;
Ok(())
}
}
#[wasm_bindgen]
pub fn allocate_matrix(rows: usize, cols: usize) -> *mut f64 {
let size = match rows.checked_mul(cols) {
Some(s) => s,
None => return core::ptr::null_mut(),
};
let layout = match std::alloc::Layout::array::<f64>(size) {
Ok(l) => l,
Err(_) => return core::ptr::null_mut(),
};
unsafe { std::alloc::alloc(layout) as *mut f64 }
}
#[wasm_bindgen]
pub fn deallocate_matrix(ptr: *mut f64, rows: usize, cols: usize) {
if ptr.is_null() {
return;
}
let size = match rows.checked_mul(cols) {
Some(s) => s,
None => return,
};
let layout = match std::alloc::Layout::array::<f64>(size) {
Ok(l) => l,
Err(_) => return,
};
unsafe { std::alloc::dealloc(ptr as *mut u8, layout) }
}
#[wasm_bindgen]
pub fn benchmark_matrix_multiply(size: usize) -> f64 {
let start = js_sys::Date::now();
let matrix_a = Matrix::identity(size);
let matrix_b = Matrix::identity(size);
let _result = matrix_a.multiply(&matrix_b);
js_sys::Date::now() - start
}
#[wasm_bindgen]
pub fn get_wasm_memory_usage() -> usize {
#[cfg(target_arch = "wasm32")]
{
use core::arch::wasm32;
unsafe {
wasm32::memory_size(0) * 65536 }
}
#[cfg(not(target_arch = "wasm32"))]
{
0
}
}
#[wasm_bindgen]
pub fn get_features() -> JsValue {
let mut features = Vec::new();
#[cfg(feature = "simd")]
features.push("simd");
#[cfg(feature = "wasm")]
features.push("wasm");
#[cfg(feature = "std")]
features.push("std");
serde_wasm_bindgen::to_value(&features).unwrap()
}
#[wasm_bindgen]
pub fn enable_simd() -> bool {
#[cfg(feature = "simd")]
{
#[cfg(target_arch = "wasm32")]
{
cfg!(target_feature = "simd128")
}
#[cfg(not(target_arch = "wasm32"))]
{
crate::has_simd_support()
}
}
#[cfg(not(feature = "simd"))]
{
false
}
}