use wasm_bindgen::prelude::*;
use ghostflow_core::Tensor;
/// Module start hook, registered via `#[wasm_bindgen(start)]`.
///
/// Installs `console_error_panic_hook` so that Rust panics are reported to the
/// JS console with their message rather than as an opaque wasm trap.
#[wasm_bindgen(start)]
pub fn init() {
    // NOTE(review): per the crate's docs, only the first call installs the hook;
    // later calls are no-ops.
    console_error_panic_hook::set_once()
}
/// JS-facing handle around a `ghostflow_core::Tensor`.
///
/// The wrapped tensor is a private field, so JS code cannot reach it directly;
/// it is only accessible through the methods this type exports.
#[wasm_bindgen]
pub struct WasmTensor {
    /// The underlying ghostflow tensor.
    inner: Tensor,
}
#[wasm_bindgen]
impl WasmTensor {
#[wasm_bindgen(constructor)]
pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
let tensor = Tensor::from_slice(&data, &shape)
.map_err(|e| JsValue::from_str(&format!("Failed to create tensor: {:?}", e)))?;
Ok(WasmTensor { inner: tensor })
}
#[wasm_bindgen(js_name = zeros)]
pub fn zeros(shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
let numel: usize = shape.iter().product();
let data = vec![0.0f32; numel];
Self::new(data, shape)
}
#[wasm_bindgen(js_name = ones)]
pub fn ones(shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
let numel: usize = shape.iter().product();
let data = vec![1.0f32; numel];
Self::new(data, shape)
}
#[wasm_bindgen(js_name = shape)]
pub fn shape(&self) -> Vec<usize> {
self.inner.dims().to_vec()
}
#[wasm_bindgen(js_name = toArray)]
pub fn to_array(&self) -> Vec<f32> {
self.inner.data_f32()
}
#[wasm_bindgen(js_name = toString)]
pub fn to_string_js(&self) -> String {
format!("{:?}", self.inner)
}
}
/// Returns the version string reported by these bindings.
///
/// NOTE(review): this is a hard-coded literal, not read from Cargo metadata;
/// keep it in sync with releases by hand.
#[wasm_bindgen]
pub fn version() -> String {
    String::from("0.5.0")
}
/// Builds tensors A and B from flat buffers and shapes, then returns `A.matmul(B)`.
///
/// Exported to JS as `matmul`.
///
/// # Errors
///
/// Each error is returned as a JS string and reported in this order:
/// - `"Failed to create tensor A: …"` if `Tensor::from_slice` rejects A's inputs.
/// - `"Failed to create tensor B: …"` if it rejects B's inputs.
/// - `"Matrix multiplication failed: …"` if `Tensor::matmul` returns an error.
#[wasm_bindgen(js_name = matmul)]
pub fn matmul(a_data: Vec<f32>, a_shape: Vec<usize>, b_data: Vec<f32>, b_shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
    Tensor::from_slice(&a_data, &a_shape)
        .map_err(|err| JsValue::from_str(&format!("Failed to create tensor A: {:?}", err)))
        .and_then(|lhs| {
            // B is only built once A succeeded, preserving error precedence.
            Tensor::from_slice(&b_data, &b_shape)
                .map_err(|err| {
                    JsValue::from_str(&format!("Failed to create tensor B: {:?}", err))
                })
                .map(|rhs| (lhs, rhs))
        })
        .and_then(|(lhs, rhs)| {
            lhs.matmul(&rhs).map_err(|err| {
                JsValue::from_str(&format!("Matrix multiplication failed: {:?}", err))
            })
        })
        .map(|product| WasmTensor { inner: product })
}