Skip to main content

scirs2_wasm/
array.rs

1//! N-dimensional array operations for WASM
2
3use crate::error::WasmError;
4use crate::utils::{js_array_to_vec_f64, parse_shape, typed_array_to_vec_f64};
5use scirs2_core::ndarray::{Array1, ArrayD};
6use wasm_bindgen::prelude::*;
7
8/// A wrapper around ndarray for JavaScript interop
9#[wasm_bindgen]
10pub struct WasmArray {
11    data: ArrayD<f64>,
12}
13
14// Internal implementation for accessing data
15impl WasmArray {
16    pub(crate) fn from_array(data: ArrayD<f64>) -> Self {
17        Self { data }
18    }
19
20    pub(crate) fn data(&self) -> &ArrayD<f64> {
21        &self.data
22    }
23}
24
25#[wasm_bindgen]
26impl WasmArray {
27    /// Create a 1D array from a JavaScript array or typed array
28    #[wasm_bindgen(constructor)]
29    pub fn new(data: &JsValue) -> Result<WasmArray, JsValue> {
30        let vec = if data.is_array() {
31            let array = js_sys::Array::from(data);
32            js_array_to_vec_f64(&array)?
33        } else {
34            typed_array_to_vec_f64(data)?
35        };
36
37        let array = Array1::from_vec(vec).into_dyn();
38        Ok(WasmArray { data: array })
39    }
40
41    /// Create an array with a specific shape from flat data
42    #[wasm_bindgen]
43    pub fn from_shape(shape: &JsValue, data: &JsValue) -> Result<WasmArray, JsValue> {
44        let shape_vec = parse_shape(shape)?;
45        let data_vec = if data.is_array() {
46            let array = js_sys::Array::from(data);
47            js_array_to_vec_f64(&array)?
48        } else {
49            typed_array_to_vec_f64(data)?
50        };
51
52        let total_size: usize = shape_vec.iter().product();
53        if data_vec.len() != total_size {
54            return Err(WasmError::ShapeMismatch {
55                expected: vec![total_size],
56                actual: vec![data_vec.len()],
57            }
58            .into());
59        }
60
61        let array = ArrayD::from_shape_vec(shape_vec, data_vec)
62            .map_err(|e: ndarray::ShapeError| WasmError::InvalidDimensions(e.to_string()))?;
63
64        Ok(WasmArray { data: array })
65    }
66
67    /// Create an array of zeros with the given shape
68    #[wasm_bindgen]
69    pub fn zeros(shape: &JsValue) -> Result<WasmArray, JsValue> {
70        let shape_vec = parse_shape(shape)?;
71        let array = ArrayD::zeros(shape_vec);
72        Ok(WasmArray { data: array })
73    }
74
75    /// Create an array of ones with the given shape
76    #[wasm_bindgen]
77    pub fn ones(shape: &JsValue) -> Result<WasmArray, JsValue> {
78        let shape_vec = parse_shape(shape)?;
79        let array = ArrayD::ones(shape_vec);
80        Ok(WasmArray { data: array })
81    }
82
83    /// Create an array filled with a constant value
84    #[wasm_bindgen]
85    pub fn full(shape: &JsValue, value: f64) -> Result<WasmArray, JsValue> {
86        let shape_vec = parse_shape(shape)?;
87        let array = ArrayD::from_elem(shape_vec, value);
88        Ok(WasmArray { data: array })
89    }
90
91    /// Create an evenly spaced array (like numpy.linspace)
92    #[wasm_bindgen]
93    pub fn linspace(start: f64, end: f64, num: usize) -> Result<WasmArray, JsValue> {
94        if num == 0 {
95            return Err(WasmError::InvalidParameter("num must be > 0".to_string()).into());
96        }
97
98        let step = if num > 1 {
99            (end - start) / (num - 1) as f64
100        } else {
101            0.0
102        };
103
104        let vec: Vec<f64> = (0..num).map(|i| start + i as f64 * step).collect();
105
106        let array = Array1::from_vec(vec).into_dyn();
107        Ok(WasmArray { data: array })
108    }
109
110    /// Create an array with evenly spaced values (like numpy.arange)
111    #[wasm_bindgen]
112    pub fn arange(start: f64, end: f64, step: f64) -> Result<WasmArray, JsValue> {
113        if step == 0.0 {
114            return Err(WasmError::InvalidParameter("step cannot be zero".to_string()).into());
115        }
116
117        if (end - start).signum() != step.signum() {
118            return Err(WasmError::InvalidParameter(
119                "step direction does not match range".to_string(),
120            )
121            .into());
122        }
123
124        let num = ((end - start) / step).abs().ceil() as usize;
125        let vec: Vec<f64> = (0..num).map(|i| start + i as f64 * step).collect();
126
127        let array = Array1::from_vec(vec).into_dyn();
128        Ok(WasmArray { data: array })
129    }
130
131    /// Get the shape of the array
132    #[wasm_bindgen]
133    pub fn shape(&self) -> js_sys::Array {
134        let shape = self.data.shape();
135        let array = js_sys::Array::new_with_length(shape.len() as u32);
136
137        for (i, &dim) in shape.iter().enumerate() {
138            array.set(i as u32, JsValue::from_f64(dim as f64));
139        }
140
141        array
142    }
143
144    /// Get the number of dimensions
145    #[wasm_bindgen]
146    pub fn ndim(&self) -> usize {
147        self.data.ndim()
148    }
149
150    /// Get the total number of elements
151    #[wasm_bindgen]
152    pub fn len(&self) -> usize {
153        self.data.len()
154    }
155
156    /// Check if the array is empty
157    #[wasm_bindgen]
158    pub fn is_empty(&self) -> bool {
159        self.data.is_empty()
160    }
161
162    /// Convert to a flat JavaScript Float64Array
163    #[wasm_bindgen]
164    pub fn to_array(&self) -> js_sys::Float64Array {
165        let vec: Vec<f64> = self.data.iter().copied().collect();
166        let array = js_sys::Float64Array::new_with_length(vec.len() as u32);
167        array.copy_from(&vec);
168        array
169    }
170
171    /// Convert to a JavaScript array (nested for multi-dimensional arrays)
172    #[wasm_bindgen]
173    pub fn to_nested_array(&self) -> JsValue {
174        // For simplicity, return flat array with shape info
175        // In production, implement proper nested array conversion
176        let vec: Vec<f64> = self.data.iter().copied().collect();
177        serde_wasm_bindgen::to_value(&vec).unwrap_or(JsValue::NULL)
178    }
179
180    /// Get a value at the specified index (flat indexing)
181    #[wasm_bindgen]
182    pub fn get(&self, index: usize) -> Result<f64, JsValue> {
183        self.data
184            .as_slice()
185            .and_then(|s| s.get(index).copied())
186            .ok_or_else(|| {
187                WasmError::IndexOutOfBounds(format!(
188                    "Index {} out of bounds for array of length {}",
189                    index,
190                    self.len()
191                ))
192                .into()
193            })
194    }
195
196    /// Set a value at the specified index (flat indexing)
197    #[wasm_bindgen]
198    pub fn set(&mut self, index: usize, value: f64) -> Result<(), JsValue> {
199        self.data
200            .as_slice_mut()
201            .and_then(|s| s.get_mut(index))
202            .map(|v| *v = value)
203            .ok_or_else(|| {
204                WasmError::IndexOutOfBounds(format!(
205                    "Index {} out of bounds for array of length {}",
206                    index,
207                    self.len()
208                ))
209                .into()
210            })
211    }
212
213    /// Reshape the array
214    #[wasm_bindgen]
215    pub fn reshape(&self, new_shape: &JsValue) -> Result<WasmArray, JsValue> {
216        let shape_vec = parse_shape(new_shape)?;
217        let total_size: usize = shape_vec.iter().product();
218
219        if total_size != self.len() {
220            return Err(WasmError::ShapeMismatch {
221                expected: vec![self.len()],
222                actual: vec![total_size],
223            }
224            .into());
225        }
226
227        let vec: Vec<f64> = self.data.iter().copied().collect();
228        let array = ArrayD::from_shape_vec(shape_vec, vec)
229            .map_err(|e: ndarray::ShapeError| WasmError::InvalidDimensions(e.to_string()))?;
230
231        Ok(WasmArray { data: array })
232    }
233
234    /// Transpose the array (2D only for now)
235    #[wasm_bindgen]
236    pub fn transpose(&self) -> Result<WasmArray, JsValue> {
237        if self.ndim() != 2 {
238            return Err(WasmError::InvalidDimensions(
239                "Transpose is only supported for 2D arrays".to_string(),
240            )
241            .into());
242        }
243
244        let transposed = self
245            .data
246            .clone()
247            .into_dimensionality::<ndarray::Ix2>()
248            .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?
249            .t()
250            .to_owned()
251            .into_dyn();
252
253        Ok(WasmArray { data: transposed })
254    }
255
256    /// Clone the array
257    #[allow(clippy::should_implement_trait)]
258    #[wasm_bindgen]
259    pub fn clone(&self) -> WasmArray {
260        WasmArray {
261            data: self.data.clone(),
262        }
263    }
264}
265
266/// Add two arrays element-wise
267#[wasm_bindgen]
268pub fn add(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
269    if a.data().shape() != b.data().shape() {
270        return Err(WasmError::ShapeMismatch {
271            expected: a.data().shape().to_vec(),
272            actual: b.data().shape().to_vec(),
273        }
274        .into());
275    }
276
277    Ok(WasmArray {
278        data: a.data() + b.data(),
279    })
280}
281
282/// Subtract two arrays element-wise
283#[wasm_bindgen]
284pub fn subtract(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
285    if a.data().shape() != b.data().shape() {
286        return Err(WasmError::ShapeMismatch {
287            expected: a.data().shape().to_vec(),
288            actual: b.data().shape().to_vec(),
289        }
290        .into());
291    }
292
293    Ok(WasmArray {
294        data: a.data() - b.data(),
295    })
296}
297
298/// Multiply two arrays element-wise
299#[wasm_bindgen]
300pub fn multiply(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
301    if a.data().shape() != b.data().shape() {
302        return Err(WasmError::ShapeMismatch {
303            expected: a.data().shape().to_vec(),
304            actual: b.data().shape().to_vec(),
305        }
306        .into());
307    }
308
309    Ok(WasmArray {
310        data: a.data() * b.data(),
311    })
312}
313
314/// Divide two arrays element-wise
315#[wasm_bindgen]
316pub fn divide(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
317    if a.data().shape() != b.data().shape() {
318        return Err(WasmError::ShapeMismatch {
319            expected: a.data().shape().to_vec(),
320            actual: b.data().shape().to_vec(),
321        }
322        .into());
323    }
324
325    Ok(WasmArray {
326        data: a.data() / b.data(),
327    })
328}
329
330/// Compute dot product (1D) or matrix multiplication (2D)
331#[wasm_bindgen]
332pub fn dot(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
333    match (a.ndim(), b.ndim()) {
334        (1, 1) => {
335            // 1D dot product
336            let a1 = a
337                .data()
338                .clone()
339                .into_dimensionality::<ndarray::Ix1>()
340                .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
341            let b1 = b
342                .data()
343                .clone()
344                .into_dimensionality::<ndarray::Ix1>()
345                .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
346
347            let result = a1.dot(&b1);
348            let array = ArrayD::from_elem(vec![], result);
349            Ok(WasmArray { data: array })
350        }
351        (2, 2) => {
352            // Matrix multiplication
353            let a2 = a
354                .data()
355                .clone()
356                .into_dimensionality::<ndarray::Ix2>()
357                .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
358            let b2 = b
359                .data()
360                .clone()
361                .into_dimensionality::<ndarray::Ix2>()
362                .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
363
364            if a2.ncols() != b2.nrows() {
365                return Err(WasmError::ShapeMismatch {
366                    expected: vec![a2.nrows(), b2.ncols()],
367                    actual: vec![a2.nrows(), a2.ncols(), b2.nrows(), b2.ncols()],
368                }
369                .into());
370            }
371
372            let result = a2.dot(&b2).into_dyn();
373            Ok(WasmArray { data: result })
374        }
375        _ => Err(WasmError::InvalidDimensions(
376            "dot only supports 1D-1D or 2D-2D arrays".to_string(),
377        )
378        .into()),
379    }
380}
381
382/// Sum all elements in the array
383#[wasm_bindgen]
384pub fn sum(arr: &WasmArray) -> f64 {
385    arr.data().sum()
386}
387
388/// Compute the mean of all elements
389#[wasm_bindgen]
390pub fn mean(arr: &WasmArray) -> f64 {
391    if arr.is_empty() {
392        return f64::NAN;
393    }
394    arr.data().sum() / arr.len() as f64
395}
396
397/// Find the minimum value
398#[wasm_bindgen]
399pub fn min(arr: &WasmArray) -> f64 {
400    arr.data().iter().copied().fold(f64::INFINITY, f64::min)
401}
402
403/// Find the maximum value
404#[wasm_bindgen]
405pub fn max(arr: &WasmArray) -> f64 {
406    arr.data().iter().copied().fold(f64::NEG_INFINITY, f64::max)
407}