sublinear 0.3.3

High-performance sublinear-time solver for asymmetric diagonally dominant systems
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
use crate::math_wasm::{Matrix, Vector};
use crate::optimized_solver::OptimizedConjugateGradientSolver;
use crate::solver_core::{ConjugateGradientSolver, SolverConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use wasm_bindgen::prelude::*;

// Configuration structure for the solver
#[derive(Serialize, Deserialize, Clone)]
pub struct WasmSolverConfig {
    pub max_iterations: usize,
    pub tolerance: f64,
    pub simd_enabled: bool,
    pub stream_chunk_size: usize,
}

impl Default for WasmSolverConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-10,
            simd_enabled: cfg!(target_feature = "simd128"),
            stream_chunk_size: 100,
        }
    }
}

// Solution step for streaming interface

/// Progress report serialized and passed to the JS callback of `solve_stream`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SolutionStep {
    /// Iteration number reported by the solver for this step.
    pub iteration: usize,
    /// Residual value reported by the solver for this step.
    pub residual: f64,
    /// Wall-clock time from `Date.now()` (milliseconds since the Unix epoch).
    pub timestamp: f64,
    /// Whether the solver reported convergence at this step.
    pub convergence: bool,
}

// Memory usage information

/// Memory figures reported to JS by `WasmSublinearSolver::memory_usage`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MemoryUsage {
    /// Bytes attributed to the most recent solve's input buffers.
    pub used: usize,
    /// Rough capacity estimate derived from `used`.
    pub capacity: usize,
}

// Main WASM solver interface

/// JS-facing handle that owns a configured conjugate-gradient solver.
#[wasm_bindgen]
pub struct WasmSublinearSolver {
    /// Underlying solver built from `config` at construction time.
    solver: OptimizedConjugateGradientSolver,
    /// Settings this solver was created with (returned by `get_config`).
    config: WasmSolverConfig,
    /// Bytes of input data seen by the most recent solve call.
    memory_usage: usize,
    /// Named JS callbacks; only ever cleared (by `dispose`) in this file.
    callbacks: HashMap<String, js_sys::Function>,
}

#[wasm_bindgen]
impl WasmSublinearSolver {
    #[wasm_bindgen(constructor)]
    pub fn new(config: JsValue) -> Result<WasmSublinearSolver, JsValue> {
        crate::set_panic_hook();

        let config: WasmSolverConfig = if config.is_undefined() {
            WasmSolverConfig::default()
        } else {
            serde_wasm_bindgen::from_value(config)
                .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?
        };

        let solver_config = SolverConfig {
            max_iterations: config.max_iterations,
            tolerance: config.tolerance,
        };

        #[cfg(feature = "std")]
        let solver = if config.simd_enabled {
            OptimizedConjugateGradientSolver::new_parallel(solver_config)
        } else {
            OptimizedConjugateGradientSolver::new(solver_config)
        };
        #[cfg(not(feature = "std"))]
        let solver = OptimizedConjugateGradientSolver::new(solver_config);

        Ok(WasmSublinearSolver {
            config,
            solver,
            callbacks: HashMap::new(),
            memory_usage: 0,
        })
    }

    #[wasm_bindgen]
    pub fn solve(
        &mut self,
        matrix_data: &[f64],
        matrix_rows: usize,
        matrix_cols: usize,
        vector_data: &[f64],
    ) -> Result<Vec<f64>, JsValue> {
        // Validate input dimensions
        if matrix_data.len() != matrix_rows * matrix_cols {
            return Err(JsValue::from_str("Matrix dimensions mismatch"));
        }

        if vector_data.len() != matrix_rows {
            return Err(JsValue::from_str("Vector size mismatch"));
        }

        // Create matrix and vector views
        let matrix = Matrix::from_slice(matrix_data, matrix_rows, matrix_cols);
        let vector = Vector::from_slice(vector_data);

        // Update memory usage tracking
        self.memory_usage = matrix_data.len() * 8 + vector_data.len() * 8;

        // Solve system
        match self.solver.solve(&matrix, &vector) {
            Ok(solution) => Ok(solution.data().to_vec()),
            Err(e) => Err(JsValue::from_str(&format!("Solver error: {}", e))),
        }
    }

    #[wasm_bindgen]
    pub fn solve_stream(
        &mut self,
        matrix_data: &[f64],
        matrix_rows: usize,
        matrix_cols: usize,
        vector_data: &[f64],
        progress_callback: &js_sys::Function,
    ) -> Result<Vec<f64>, JsValue> {
        // Validate input
        if matrix_data.len() != matrix_rows * matrix_cols {
            return Err(JsValue::from_str("Matrix dimensions mismatch"));
        }

        if vector_data.len() != matrix_rows {
            return Err(JsValue::from_str("Vector size mismatch"));
        }

        let matrix = Matrix::from_slice(matrix_data, matrix_rows, matrix_cols);
        let vector = Vector::from_slice(vector_data);

        self.memory_usage = matrix_data.len() * 8 + vector_data.len() * 8;

        // Solve with streaming callback
        let mut solution = None;
        let chunk_size = self.config.stream_chunk_size;

        let result = self
            .solver
            .solve_with_callback(&matrix, &vector, chunk_size, |step_data| {
                let timestamp = js_sys::Date::now();
                let step = SolutionStep {
                    iteration: step_data.iteration,
                    residual: step_data.residual,
                    timestamp,
                    convergence: step_data.converged,
                };

                let step_js = serde_wasm_bindgen::to_value(&step).unwrap();
                let _ = progress_callback.call1(&JsValue::NULL, &step_js);

                if step_data.converged {
                    solution = Some(step_data.solution.clone());
                }
            });

        match result {
            Ok(final_solution) => Ok(final_solution.data().to_vec()),
            Err(e) => Err(JsValue::from_str(&format!("Streaming solver error: {}", e))),
        }
    }

    #[wasm_bindgen]
    pub fn solve_batch(&mut self, batch_data: JsValue) -> Result<JsValue, JsValue> {
        #[derive(Deserialize)]
        struct BatchRequest {
            id: String,
            matrix_data: Vec<f64>,
            matrix_rows: usize,
            matrix_cols: usize,
            vector_data: Vec<f64>,
        }

        #[derive(Serialize)]
        struct BatchResult {
            id: String,
            solution: Vec<f64>,
            iterations: usize,
            error: Option<String>,
        }

        let requests: Vec<BatchRequest> = serde_wasm_bindgen::from_value(batch_data)
            .map_err(|e| JsValue::from_str(&format!("Invalid batch data: {}", e)))?;

        let mut results = Vec::new();

        for request in requests {
            let result = match self.solve(
                &request.matrix_data,
                request.matrix_rows,
                request.matrix_cols,
                &request.vector_data,
            ) {
                Ok(solution) => BatchResult {
                    id: request.id,
                    solution,
                    iterations: self.solver.get_last_iteration_count(),
                    error: None,
                },
                Err(e) => BatchResult {
                    id: request.id,
                    solution: Vec::new(),
                    iterations: 0,
                    error: Some(format!("{:?}", e)),
                },
            };
            results.push(result);
        }

        serde_wasm_bindgen::to_value(&results)
            .map_err(|e| JsValue::from_str(&format!("Failed to serialize results: {}", e)))
    }

    #[wasm_bindgen(getter)]
    pub fn memory_usage(&self) -> JsValue {
        let usage = MemoryUsage {
            used: self.memory_usage,
            capacity: self.memory_usage * 2, // Rough estimate
        };

        serde_wasm_bindgen::to_value(&usage).unwrap()
    }

    #[wasm_bindgen]
    pub fn get_config(&self) -> JsValue {
        serde_wasm_bindgen::to_value(&self.config).unwrap()
    }

    #[wasm_bindgen]
    pub fn dispose(&mut self) {
        self.callbacks.clear();
        self.memory_usage = 0;
    }
}

// Dense matrix buffer shared with JS

/// Owned, zero-initialised matrix buffer stored row by row (element `(r, c)` lives at
/// index `r * cols + c`). JS can read it without copying through `data_view`.
#[wasm_bindgen]
pub struct MatrixView {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// `rows * cols` elements; the length never changes after construction.
    data: Vec<f64>,
}

#[wasm_bindgen]
impl MatrixView {
    /// Creates a zero-filled `rows x cols` matrix.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`, which is reachable on 32-bit wasm. Without
    /// this check the product would wrap and the buffer would be shorter than the declared
    /// shape, making in-range element accesses fail.
    #[wasm_bindgen(constructor)]
    pub fn new(rows: usize, cols: usize) -> MatrixView {
        let len = rows
            .checked_mul(cols)
            .expect("MatrixView dimensions overflow usize");
        MatrixView {
            data: vec![0.0; len],
            rows,
            cols,
        }
    }

    /// Raw address of the first element in wasm linear memory.
    ///
    /// The buffer is never resized after construction, so the address stays fixed for the
    /// lifetime of this object.
    #[wasm_bindgen(getter)]
    pub fn data(&self) -> *const f64 {
        self.data.as_ptr()
    }

    /// Total number of elements (`rows * cols`).
    #[wasm_bindgen(getter)]
    pub fn length(&self) -> usize {
        self.data.len()
    }

    /// Number of rows.
    #[wasm_bindgen(getter)]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    #[wasm_bindgen(getter)]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns a `Float64Array` that aliases this matrix's storage without copying.
    ///
    /// The view points straight into wasm linear memory. It is invalidated if that memory
    /// grows (which any later allocation may trigger) or once this object is freed, so JS
    /// should use it immediately rather than retain it.
    #[wasm_bindgen]
    pub fn data_view(&self) -> js_sys::Float64Array {
        // SAFETY: `self.data` is a live, correctly sized allocation for the duration of this
        // call; keeping the returned view valid afterwards is the JS caller's responsibility
        // as documented above.
        unsafe { js_sys::Float64Array::view(&self.data) }
    }

    /// Overwrites every element by copying from `data`, which must have exactly `length`
    /// elements.
    ///
    /// # Errors
    ///
    /// Returns `"Data length mismatch"` when the lengths differ; the matrix is left unchanged.
    #[wasm_bindgen]
    pub fn set_data(&mut self, data: &[f64]) -> Result<(), JsValue> {
        if data.len() != self.data.len() {
            return Err(JsValue::from_str("Data length mismatch"));
        }
        self.data.copy_from_slice(data);
        Ok(())
    }

    /// Reads the element at (`row`, `col`).
    ///
    /// # Errors
    ///
    /// Returns `"Index out of bounds"` when `row >= rows` or `col >= cols`.
    #[wasm_bindgen]
    pub fn get_element(&self, row: usize, col: usize) -> Result<f64, JsValue> {
        if row >= self.rows || col >= self.cols {
            return Err(JsValue::from_str("Index out of bounds"));
        }
        // Cannot overflow or go out of range: row < rows, col < cols, and `new` verified
        // that rows * cols fits in usize and equals `data.len()`.
        Ok(self.data[row * self.cols + col])
    }

    /// Writes `value` at (`row`, `col`).
    ///
    /// # Errors
    ///
    /// Returns `"Index out of bounds"` when `row >= rows` or `col >= cols`.
    #[wasm_bindgen]
    pub fn set_element(&mut self, row: usize, col: usize, value: f64) -> Result<(), JsValue> {
        if row >= self.rows || col >= self.cols {
            return Err(JsValue::from_str("Index out of bounds"));
        }
        self.data[row * self.cols + col] = value;
        Ok(())
    }
}

// Memory management utilities

/// Allocates an uninitialised buffer of `rows * cols` `f64` values on the wasm heap and
/// returns its address, for JS to fill through a typed-array view.
///
/// Returns a null pointer in any of these cases:
/// - `rows * cols` is zero;
/// - the element count or byte size overflows;
/// - the allocator fails.
///
/// A non-null result must be released with [`deallocate_matrix`] using the same `rows` and
/// `cols`.
#[wasm_bindgen]
pub fn allocate_matrix(rows: usize, cols: usize) -> *mut f64 {
    // `rows * cols` wraps silently on 32-bit wasm, so the element count is checked, and
    // `Layout::array` checks the byte size. Zero-sized layouts are rejected because passing
    // one to `alloc` is undefined behavior.
    let layout = match rows.checked_mul(cols).map(std::alloc::Layout::array::<f64>) {
        Some(Ok(layout)) if layout.size() != 0 => layout,
        _ => return core::ptr::null_mut(),
    };
    // SAFETY: `layout` has non-zero size (checked above) and the alignment of f64. `alloc`
    // returns null on failure, which is passed straight through to the caller.
    unsafe { std::alloc::alloc(layout).cast::<f64>() }
}

/// Frees a buffer previously returned by [`allocate_matrix`].
///
/// Null pointers are ignored, as are dimensions whose layout is zero-sized or overflows.
/// These are exactly the cases in which `allocate_matrix` returns null.
///
/// The caller must pass a pointer obtained from `allocate_matrix` with the same `rows` and
/// `cols`, and must not free it twice. Neither condition can be checked here.
#[wasm_bindgen]
pub fn deallocate_matrix(ptr: *mut f64, rows: usize, cols: usize) {
    if ptr.is_null() {
        return;
    }
    // Recompute the layout exactly as `allocate_matrix` does; a zero-sized layout must never
    // reach `dealloc`.
    let layout = match rows.checked_mul(cols).map(std::alloc::Layout::array::<f64>) {
        Some(Ok(layout)) if layout.size() != 0 => layout,
        _ => return,
    };
    // SAFETY: per the contract above, `ptr` came from `allocate_matrix(rows, cols)`, which
    // allocated it with this same non-zero-size layout, and it has not been freed yet.
    unsafe { std::alloc::dealloc(ptr.cast::<u8>(), layout) }
}

// Utility functions for performance benchmarking

/// Times one `size x size` matrix multiplication of two identity matrices.
///
/// Returns the elapsed wall-clock time in milliseconds as measured with `Date.now()`, so the
/// resolution is coarse. Building the operands is excluded from the measurement.
#[wasm_bindgen]
pub fn benchmark_matrix_multiply(size: usize) -> f64 {
    let matrix_a = Matrix::identity(size);
    let matrix_b = Matrix::identity(size);

    let start = js_sys::Date::now();
    // `black_box` keeps the optimizer from discarding the otherwise-unused product.
    let _product = std::hint::black_box(matrix_a.multiply(&matrix_b));
    js_sys::Date::now() - start
}

/// Returns the current size of wasm linear memory 0 in bytes, or 0 on non-wasm targets.
///
/// On wasm32 a full 4 GiB memory does not fit in `usize`, so the result saturates at
/// `usize::MAX` instead of wrapping.
#[wasm_bindgen]
pub fn get_wasm_memory_usage() -> usize {
    #[cfg(target_arch = "wasm32")]
    {
        /// Size of one WebAssembly memory page in bytes.
        const WASM_PAGE_SIZE: usize = 65536;
        core::arch::wasm32::memory_size(0).saturating_mul(WASM_PAGE_SIZE)
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        0
    }
}

/// Lists the optional Cargo features compiled into this build as a JS array of strings.
///
/// Enabled features appear in the order `"simd"`, `"wasm"`, `"std"`; an empty array means
/// none of them is enabled.
#[wasm_bindgen]
pub fn get_features() -> JsValue {
    // `cfg!` keeps every entry type-checked and the element type explicit. With `#[cfg]`-gated
    // pushes into an untyped `Vec::new()`, a build with all three features disabled fails
    // type inference.
    let features: Vec<&str> = [
        ("simd", cfg!(feature = "simd")),
        ("wasm", cfg!(feature = "wasm")),
        ("std", cfg!(feature = "std")),
    ]
    .iter()
    .filter(|&&(_, enabled)| enabled)
    .map(|&(name, _)| name)
    .collect();

    serde_wasm_bindgen::to_value(&features).expect("a Vec of strings always serializes")
}

/// Reports whether SIMD is available to this build; despite its name it toggles nothing.
///
/// Without the `simd` feature this is always `false`. With it, wasm32 builds report whether
/// the `simd128` target feature was enabled at compile time, and other targets defer to
/// `crate::has_simd_support()`.
#[wasm_bindgen]
pub fn enable_simd() -> bool {
    // Exactly one of these bindings is compiled in for any combination of cfgs.
    #[cfg(all(feature = "simd", target_arch = "wasm32"))]
    let supported = cfg!(target_feature = "simd128");
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    let supported = crate::has_simd_support();
    #[cfg(not(feature = "simd"))]
    let supported = false;

    supported
}