1use crate::error::WasmError;
4use crate::utils::{js_array_to_vec_f64, parse_shape, typed_array_to_vec_f64};
5use scirs2_core::ndarray::{Array1, ArrayD};
6use wasm_bindgen::prelude::*;
7
/// N-dimensional array of `f64` values exported to JavaScript via `wasm_bindgen`.
///
/// Thin wrapper around an ndarray `ArrayD<f64>` (dynamic rank); the free
/// functions in this module operate on the wrapped array through the
/// crate-internal `data()` accessor.
#[wasm_bindgen]
pub struct WasmArray {
    // The wrapped dynamic-rank array; not exposed to JS directly.
    data: ArrayD<f64>,
}
13
14impl WasmArray {
16 pub(crate) fn from_array(data: ArrayD<f64>) -> Self {
17 Self { data }
18 }
19
20 pub(crate) fn data(&self) -> &ArrayD<f64> {
21 &self.data
22 }
23}
24
25#[wasm_bindgen]
26impl WasmArray {
27 #[wasm_bindgen(constructor)]
29 pub fn new(data: &JsValue) -> Result<WasmArray, JsValue> {
30 let vec = if data.is_array() {
31 let array = js_sys::Array::from(data);
32 js_array_to_vec_f64(&array)?
33 } else {
34 typed_array_to_vec_f64(data)?
35 };
36
37 let array = Array1::from_vec(vec).into_dyn();
38 Ok(WasmArray { data: array })
39 }
40
41 #[wasm_bindgen]
43 pub fn from_shape(shape: &JsValue, data: &JsValue) -> Result<WasmArray, JsValue> {
44 let shape_vec = parse_shape(shape)?;
45 let data_vec = if data.is_array() {
46 let array = js_sys::Array::from(data);
47 js_array_to_vec_f64(&array)?
48 } else {
49 typed_array_to_vec_f64(data)?
50 };
51
52 let total_size: usize = shape_vec.iter().product();
53 if data_vec.len() != total_size {
54 return Err(WasmError::ShapeMismatch {
55 expected: vec![total_size],
56 actual: vec![data_vec.len()],
57 }
58 .into());
59 }
60
61 let array = ArrayD::from_shape_vec(shape_vec, data_vec)
62 .map_err(|e: ndarray::ShapeError| WasmError::InvalidDimensions(e.to_string()))?;
63
64 Ok(WasmArray { data: array })
65 }
66
67 #[wasm_bindgen]
69 pub fn zeros(shape: &JsValue) -> Result<WasmArray, JsValue> {
70 let shape_vec = parse_shape(shape)?;
71 let array = ArrayD::zeros(shape_vec);
72 Ok(WasmArray { data: array })
73 }
74
75 #[wasm_bindgen]
77 pub fn ones(shape: &JsValue) -> Result<WasmArray, JsValue> {
78 let shape_vec = parse_shape(shape)?;
79 let array = ArrayD::ones(shape_vec);
80 Ok(WasmArray { data: array })
81 }
82
83 #[wasm_bindgen]
85 pub fn full(shape: &JsValue, value: f64) -> Result<WasmArray, JsValue> {
86 let shape_vec = parse_shape(shape)?;
87 let array = ArrayD::from_elem(shape_vec, value);
88 Ok(WasmArray { data: array })
89 }
90
91 #[wasm_bindgen]
93 pub fn linspace(start: f64, end: f64, num: usize) -> Result<WasmArray, JsValue> {
94 if num == 0 {
95 return Err(WasmError::InvalidParameter("num must be > 0".to_string()).into());
96 }
97
98 let step = if num > 1 {
99 (end - start) / (num - 1) as f64
100 } else {
101 0.0
102 };
103
104 let vec: Vec<f64> = (0..num).map(|i| start + i as f64 * step).collect();
105
106 let array = Array1::from_vec(vec).into_dyn();
107 Ok(WasmArray { data: array })
108 }
109
110 #[wasm_bindgen]
112 pub fn arange(start: f64, end: f64, step: f64) -> Result<WasmArray, JsValue> {
113 if step == 0.0 {
114 return Err(WasmError::InvalidParameter("step cannot be zero".to_string()).into());
115 }
116
117 if (end - start).signum() != step.signum() {
118 return Err(WasmError::InvalidParameter(
119 "step direction does not match range".to_string(),
120 )
121 .into());
122 }
123
124 let num = ((end - start) / step).abs().ceil() as usize;
125 let vec: Vec<f64> = (0..num).map(|i| start + i as f64 * step).collect();
126
127 let array = Array1::from_vec(vec).into_dyn();
128 Ok(WasmArray { data: array })
129 }
130
131 #[wasm_bindgen]
133 pub fn shape(&self) -> js_sys::Array {
134 let shape = self.data.shape();
135 let array = js_sys::Array::new_with_length(shape.len() as u32);
136
137 for (i, &dim) in shape.iter().enumerate() {
138 array.set(i as u32, JsValue::from_f64(dim as f64));
139 }
140
141 array
142 }
143
144 #[wasm_bindgen]
146 pub fn ndim(&self) -> usize {
147 self.data.ndim()
148 }
149
150 #[wasm_bindgen]
152 pub fn len(&self) -> usize {
153 self.data.len()
154 }
155
156 #[wasm_bindgen]
158 pub fn is_empty(&self) -> bool {
159 self.data.is_empty()
160 }
161
162 #[wasm_bindgen]
164 pub fn to_array(&self) -> js_sys::Float64Array {
165 let vec: Vec<f64> = self.data.iter().copied().collect();
166 let array = js_sys::Float64Array::new_with_length(vec.len() as u32);
167 array.copy_from(&vec);
168 array
169 }
170
171 #[wasm_bindgen]
173 pub fn to_nested_array(&self) -> JsValue {
174 let vec: Vec<f64> = self.data.iter().copied().collect();
177 serde_wasm_bindgen::to_value(&vec).unwrap_or(JsValue::NULL)
178 }
179
180 #[wasm_bindgen]
182 pub fn get(&self, index: usize) -> Result<f64, JsValue> {
183 self.data
184 .as_slice()
185 .and_then(|s| s.get(index).copied())
186 .ok_or_else(|| {
187 WasmError::IndexOutOfBounds(format!(
188 "Index {} out of bounds for array of length {}",
189 index,
190 self.len()
191 ))
192 .into()
193 })
194 }
195
196 #[wasm_bindgen]
198 pub fn set(&mut self, index: usize, value: f64) -> Result<(), JsValue> {
199 self.data
200 .as_slice_mut()
201 .and_then(|s| s.get_mut(index))
202 .map(|v| *v = value)
203 .ok_or_else(|| {
204 WasmError::IndexOutOfBounds(format!(
205 "Index {} out of bounds for array of length {}",
206 index,
207 self.len()
208 ))
209 .into()
210 })
211 }
212
213 #[wasm_bindgen]
215 pub fn reshape(&self, new_shape: &JsValue) -> Result<WasmArray, JsValue> {
216 let shape_vec = parse_shape(new_shape)?;
217 let total_size: usize = shape_vec.iter().product();
218
219 if total_size != self.len() {
220 return Err(WasmError::ShapeMismatch {
221 expected: vec![self.len()],
222 actual: vec![total_size],
223 }
224 .into());
225 }
226
227 let vec: Vec<f64> = self.data.iter().copied().collect();
228 let array = ArrayD::from_shape_vec(shape_vec, vec)
229 .map_err(|e: ndarray::ShapeError| WasmError::InvalidDimensions(e.to_string()))?;
230
231 Ok(WasmArray { data: array })
232 }
233
234 #[wasm_bindgen]
236 pub fn transpose(&self) -> Result<WasmArray, JsValue> {
237 if self.ndim() != 2 {
238 return Err(WasmError::InvalidDimensions(
239 "Transpose is only supported for 2D arrays".to_string(),
240 )
241 .into());
242 }
243
244 let transposed = self
245 .data
246 .clone()
247 .into_dimensionality::<ndarray::Ix2>()
248 .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?
249 .t()
250 .to_owned()
251 .into_dyn();
252
253 Ok(WasmArray { data: transposed })
254 }
255
256 #[allow(clippy::should_implement_trait)]
258 #[wasm_bindgen]
259 pub fn clone(&self) -> WasmArray {
260 WasmArray {
261 data: self.data.clone(),
262 }
263 }
264}
265
266#[wasm_bindgen]
268pub fn add(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
269 if a.data().shape() != b.data().shape() {
270 return Err(WasmError::ShapeMismatch {
271 expected: a.data().shape().to_vec(),
272 actual: b.data().shape().to_vec(),
273 }
274 .into());
275 }
276
277 Ok(WasmArray {
278 data: a.data() + b.data(),
279 })
280}
281
282#[wasm_bindgen]
284pub fn subtract(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
285 if a.data().shape() != b.data().shape() {
286 return Err(WasmError::ShapeMismatch {
287 expected: a.data().shape().to_vec(),
288 actual: b.data().shape().to_vec(),
289 }
290 .into());
291 }
292
293 Ok(WasmArray {
294 data: a.data() - b.data(),
295 })
296}
297
298#[wasm_bindgen]
300pub fn multiply(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
301 if a.data().shape() != b.data().shape() {
302 return Err(WasmError::ShapeMismatch {
303 expected: a.data().shape().to_vec(),
304 actual: b.data().shape().to_vec(),
305 }
306 .into());
307 }
308
309 Ok(WasmArray {
310 data: a.data() * b.data(),
311 })
312}
313
314#[wasm_bindgen]
316pub fn divide(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
317 if a.data().shape() != b.data().shape() {
318 return Err(WasmError::ShapeMismatch {
319 expected: a.data().shape().to_vec(),
320 actual: b.data().shape().to_vec(),
321 }
322 .into());
323 }
324
325 Ok(WasmArray {
326 data: a.data() / b.data(),
327 })
328}
329
330#[wasm_bindgen]
332pub fn dot(a: &WasmArray, b: &WasmArray) -> Result<WasmArray, JsValue> {
333 match (a.ndim(), b.ndim()) {
334 (1, 1) => {
335 let a1 = a
337 .data()
338 .clone()
339 .into_dimensionality::<ndarray::Ix1>()
340 .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
341 let b1 = b
342 .data()
343 .clone()
344 .into_dimensionality::<ndarray::Ix1>()
345 .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
346
347 let result = a1.dot(&b1);
348 let array = ArrayD::from_elem(vec![], result);
349 Ok(WasmArray { data: array })
350 }
351 (2, 2) => {
352 let a2 = a
354 .data()
355 .clone()
356 .into_dimensionality::<ndarray::Ix2>()
357 .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
358 let b2 = b
359 .data()
360 .clone()
361 .into_dimensionality::<ndarray::Ix2>()
362 .map_err(|e: ndarray::ShapeError| WasmError::ComputationError(e.to_string()))?;
363
364 if a2.ncols() != b2.nrows() {
365 return Err(WasmError::ShapeMismatch {
366 expected: vec![a2.nrows(), b2.ncols()],
367 actual: vec![a2.nrows(), a2.ncols(), b2.nrows(), b2.ncols()],
368 }
369 .into());
370 }
371
372 let result = a2.dot(&b2).into_dyn();
373 Ok(WasmArray { data: result })
374 }
375 _ => Err(WasmError::InvalidDimensions(
376 "dot only supports 1D-1D or 2D-2D arrays".to_string(),
377 )
378 .into()),
379 }
380}
381
382#[wasm_bindgen]
384pub fn sum(arr: &WasmArray) -> f64 {
385 arr.data().sum()
386}
387
388#[wasm_bindgen]
390pub fn mean(arr: &WasmArray) -> f64 {
391 if arr.is_empty() {
392 return f64::NAN;
393 }
394 arr.data().sum() / arr.len() as f64
395}
396
397#[wasm_bindgen]
399pub fn min(arr: &WasmArray) -> f64 {
400 arr.data().iter().copied().fold(f64::INFINITY, f64::min)
401}
402
403#[wasm_bindgen]
405pub fn max(arr: &WasmArray) -> f64 {
406 arr.data().iter().copied().fold(f64::NEG_INFINITY, f64::max)
407}