use crate::array::Array;
use crate::error::Result as NumRs2Result;
use crate::stats::Statistics;
use wasm_bindgen::prelude::*;
/// JavaScript-facing wrapper around a crate `Array<f64>`.
///
/// Exported through `wasm_bindgen`; every operation exposed to JS is
/// implemented by delegating to the wrapped array.
#[wasm_bindgen]
pub struct WasmArray {
// The wrapped numeric array; kept private so JS can only reach it through
// the exported methods.
inner: Array<f64>,
}
#[wasm_bindgen]
impl WasmArray {
#[wasm_bindgen]
pub fn zeros(shape: &[usize]) -> WasmArray {
WasmArray {
inner: Array::zeros(shape),
}
}
#[wasm_bindgen]
pub fn ones(shape: &[usize]) -> WasmArray {
WasmArray {
inner: Array::ones(shape),
}
}
#[wasm_bindgen]
pub fn full(shape: &[usize], value: f64) -> WasmArray {
WasmArray {
inner: Array::full(shape, value),
}
}
#[wasm_bindgen]
pub fn from_vec(data: &[f64], shape: &[usize]) -> Result<WasmArray, JsValue> {
let total_size: usize = shape.iter().product();
if data.len() != total_size {
return Err(JsValue::from_str(&format!(
"Data length {} does not match shape product {}",
data.len(),
total_size
)));
}
Ok(WasmArray {
inner: Array::from_vec(data.to_vec()).reshape(shape),
})
}
#[wasm_bindgen]
pub fn shape(&self) -> Vec<usize> {
self.inner.shape()
}
#[wasm_bindgen]
pub fn ndim(&self) -> usize {
self.inner.ndim()
}
#[wasm_bindgen]
pub fn size(&self) -> usize {
self.inner.size()
}
#[wasm_bindgen]
pub fn reshape(&self, new_shape: &[usize]) -> Result<WasmArray, JsValue> {
let new_size: usize = new_shape.iter().product();
if new_size != self.inner.size() {
return Err(JsValue::from_str(&format!(
"Cannot reshape array of size {} into shape with size {}",
self.inner.size(),
new_size
)));
}
Ok(WasmArray {
inner: self.inner.reshape(new_shape),
})
}
#[wasm_bindgen]
pub fn transpose(&self) -> WasmArray {
WasmArray {
inner: self.inner.transpose(),
}
}
#[wasm_bindgen]
pub fn get(&self, indices: &[usize]) -> Result<f64, JsValue> {
self.inner
.get(indices)
.map_err(|e| JsValue::from_str(&format!("Get error: {}", e)))
}
#[wasm_bindgen]
pub fn set(&mut self, indices: &[usize], value: f64) -> Result<(), JsValue> {
self.inner
.set(indices, value)
.map_err(|e| JsValue::from_str(&format!("Set error: {}", e)))
}
#[wasm_bindgen]
pub fn to_vec(&self) -> Vec<f64> {
self.inner.to_vec()
}
#[wasm_bindgen]
pub fn add(&self, other: &WasmArray) -> Result<WasmArray, JsValue> {
if self.inner.shape() != other.inner.shape() {
return Err(JsValue::from_str("Arrays must have the same shape"));
}
Ok(WasmArray {
inner: self.inner.add(&other.inner),
})
}
#[wasm_bindgen]
pub fn subtract(&self, other: &WasmArray) -> Result<WasmArray, JsValue> {
if self.inner.shape() != other.inner.shape() {
return Err(JsValue::from_str("Arrays must have the same shape"));
}
Ok(WasmArray {
inner: self.inner.subtract(&other.inner),
})
}
#[wasm_bindgen]
pub fn multiply(&self, other: &WasmArray) -> Result<WasmArray, JsValue> {
if self.inner.shape() != other.inner.shape() {
return Err(JsValue::from_str("Arrays must have the same shape"));
}
Ok(WasmArray {
inner: self.inner.multiply(&other.inner),
})
}
#[wasm_bindgen]
pub fn divide(&self, other: &WasmArray) -> Result<WasmArray, JsValue> {
if self.inner.shape() != other.inner.shape() {
return Err(JsValue::from_str("Arrays must have the same shape"));
}
Ok(WasmArray {
inner: self.inner.divide(&other.inner),
})
}
#[wasm_bindgen]
pub fn add_scalar(&self, scalar: f64) -> WasmArray {
WasmArray {
inner: self.inner.add_scalar(scalar),
}
}
#[wasm_bindgen]
pub fn multiply_scalar(&self, scalar: f64) -> WasmArray {
WasmArray {
inner: self.inner.multiply_scalar(scalar),
}
}
#[wasm_bindgen]
pub fn sum(&self) -> f64 {
self.inner.sum()
}
#[wasm_bindgen]
pub fn mean(&self) -> f64 {
self.inner.mean()
}
#[wasm_bindgen]
pub fn min(&self) -> f64 {
self.inner.min()
}
#[wasm_bindgen]
pub fn max(&self) -> f64 {
self.inner.max()
}
}
/// Crate-internal conversions between `WasmArray` and the wrapped `Array<f64>`;
/// not exported to JavaScript.
impl WasmArray {
    /// Takes ownership of `array` and wraps it, moving rather than copying it.
    pub(crate) fn from_array(array: Array<f64>) -> WasmArray {
        Self { inner: array }
    }

    /// Consumes the wrapper and returns the `Array<f64>` it held.
    pub(crate) fn into_inner(self) -> Array<f64> {
        let WasmArray { inner } = self;
        inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the failure branches of `WasmArray` construct a `JsValue`,
    // which presumably only works when running on a wasm target, so these
    // native tests exercise success paths only — confirm before adding
    // error-case tests here.

    #[test]
    fn test_zeros() {
        let arr = WasmArray::zeros(&[2, 3]);
        assert_eq!(arr.shape(), vec![2, 3]);
        assert_eq!(arr.size(), 6);
    }

    #[test]
    fn test_ones() {
        let arr = WasmArray::ones(&[2, 3]);
        assert_eq!(arr.sum(), 6.0);
    }

    #[test]
    fn test_from_vec() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = WasmArray::from_vec(&data, &[2, 3]).expect("from_vec should succeed");
        assert_eq!(arr.shape(), vec![2, 3]);
        assert_eq!(arr.to_vec(), data);
    }

    #[test]
    fn test_reshape() {
        let arr = WasmArray::zeros(&[2, 3]);
        let reshaped = arr.reshape(&[3, 2]).expect("reshape should succeed");
        assert_eq!(reshaped.shape(), vec![3, 2]);
    }

    #[test]
    fn test_transpose() {
        let arr = WasmArray::zeros(&[2, 3]);
        let t = arr.transpose();
        assert_eq!(t.shape(), vec![3, 2]);
    }

    #[test]
    fn test_arithmetic() {
        let a = WasmArray::ones(&[2, 3]);
        let b = WasmArray::full(&[2, 3], 2.0);
        let sum = a.add(&b).expect("add should succeed");
        assert_eq!(sum.sum(), 18.0);
        let diff = b.subtract(&a).expect("subtract should succeed");
        assert_eq!(diff.sum(), 6.0);
        let prod = a.multiply(&b).expect("multiply should succeed");
        assert_eq!(prod.sum(), 12.0);
        let quot = b.divide(&a).expect("divide should succeed");
        assert_eq!(quot.sum(), 12.0);
    }

    #[test]
    fn test_scalar_ops() {
        let arr = WasmArray::ones(&[2, 3]);
        let added = arr.add_scalar(5.0);
        assert_eq!(added.sum(), 36.0);
        let scaled = arr.multiply_scalar(3.0);
        assert_eq!(scaled.sum(), 18.0);
    }

    #[test]
    fn test_get_set_roundtrip() {
        let mut arr = WasmArray::zeros(&[2, 3]);
        arr.set(&[1, 2], 7.5).expect("set should succeed");
        assert_eq!(arr.get(&[1, 2]).expect("get should succeed"), 7.5);
        // Only the one written element should have changed.
        assert_eq!(arr.sum(), 7.5);
    }

    #[test]
    fn test_reductions() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = WasmArray::from_vec(&data, &[2, 3]).expect("from_vec should succeed");
        assert_eq!(arr.sum(), 21.0);
        assert_eq!(arr.mean(), 3.5);
        assert_eq!(arr.min(), 1.0);
        assert_eq!(arr.max(), 6.0);
    }
}