1use wasm_bindgen::prelude::*;
6use ghostflow_core::Tensor;
7
8#[wasm_bindgen(start)]
10pub fn init() {
11 console_error_panic_hook::set_once();
12}
13
14#[wasm_bindgen]
16pub struct WasmTensor {
17 inner: Tensor,
18}
19
20#[wasm_bindgen]
21impl WasmTensor {
22 #[wasm_bindgen(constructor)]
24 pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
25 let tensor = Tensor::from_slice(&data, &shape)
26 .map_err(|e| JsValue::from_str(&format!("Failed to create tensor: {:?}", e)))?;
27 Ok(WasmTensor { inner: tensor })
28 }
29
30 #[wasm_bindgen(js_name = zeros)]
32 pub fn zeros(shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
33 let numel: usize = shape.iter().product();
34 let data = vec![0.0f32; numel];
35 Self::new(data, shape)
36 }
37
38 #[wasm_bindgen(js_name = ones)]
40 pub fn ones(shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
41 let numel: usize = shape.iter().product();
42 let data = vec![1.0f32; numel];
43 Self::new(data, shape)
44 }
45
46 #[wasm_bindgen(js_name = shape)]
48 pub fn shape(&self) -> Vec<usize> {
49 self.inner.dims().to_vec()
50 }
51
52 #[wasm_bindgen(js_name = toArray)]
54 pub fn to_array(&self) -> Vec<f32> {
55 self.inner.data_f32()
56 }
57
58 #[wasm_bindgen(js_name = toString)]
60 pub fn to_string_js(&self) -> String {
61 format!("{:?}", self.inner)
62 }
63}
64
65#[wasm_bindgen]
67pub fn version() -> String {
68 "0.5.0".to_string()
69}
70
71#[wasm_bindgen(js_name = matmul)]
73pub fn matmul(a_data: Vec<f32>, a_shape: Vec<usize>, b_data: Vec<f32>, b_shape: Vec<usize>) -> Result<WasmTensor, JsValue> {
74 let a = Tensor::from_slice(&a_data, &a_shape)
75 .map_err(|e| JsValue::from_str(&format!("Failed to create tensor A: {:?}", e)))?;
76 let b = Tensor::from_slice(&b_data, &b_shape)
77 .map_err(|e| JsValue::from_str(&format!("Failed to create tensor B: {:?}", e)))?;
78
79 let result = a.matmul(&b)
80 .map_err(|e| JsValue::from_str(&format!("Matrix multiplication failed: {:?}", e)))?;
81
82 Ok(WasmTensor { inner: result })
83}