#[cfg(target_arch = "wasm32")]
use js_sys::{ArrayBuffer, Float32Array, Uint8Array, WebAssembly};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
#[cfg(target_arch = "wasm32")]
use web_sys::{console, window, Performance};
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
use std::arch::wasm32::*;
use crate::{DType, Device, Result, TensorError};
/// Browser-side tensor helper for the `wasm32` target: JS typed-array
/// interop, `Performance`-based timing, console logging, and element-wise
/// math that uses WASM SIMD128 when available with scalar fallbacks.
#[cfg(target_arch = "wasm32")]
pub struct WasmTensorOps {
    // Optional dedicated `WebAssembly.Memory`; `None` until `init_memory`
    // is called.
    memory: Option<WebAssembly::Memory>,
    // Browser `Performance` handle used by `now()`; `None` when no `window`
    // object exists (e.g. some worker contexts).
    performance: Option<Performance>,
}
#[cfg(target_arch = "wasm32")]
impl WasmTensorOps {
    /// Creates a new instance with no backing memory, capturing the browser
    /// `Performance` handle when a `window` object is available.
    pub fn new() -> Self {
        let performance = window().and_then(|win| win.performance());
        Self {
            memory: None,
            performance,
        }
    }

    /// Allocates a dedicated `WebAssembly.Memory` with `initial_pages`
    /// pages (64 KiB each) and stores it on `self`.
    ///
    /// # Errors
    /// Returns a device error if the JS descriptor cannot be built or the
    /// engine refuses the allocation.
    pub fn init_memory(&mut self, initial_pages: u32) -> Result<()> {
        let memory_descriptor = js_sys::Object::new();
        js_sys::Reflect::set(&memory_descriptor, &"initial".into(), &initial_pages.into())
            .map_err(|_| TensorError::device_error_simple("Failed to set memory descriptor"))?;
        let memory = WebAssembly::Memory::new(&memory_descriptor)
            .map_err(|_| TensorError::device_error_simple("Failed to create WASM memory"))?;
        self.memory = Some(memory);
        Ok(())
    }

    /// High-resolution timestamp in milliseconds, or `0.0` when no
    /// `Performance` object was available at construction.
    pub fn now(&self) -> f64 {
        self.performance.as_ref().map(|p| p.now()).unwrap_or(0.0)
    }

    /// Logs `message` to the browser console.
    pub fn log(&self, message: &str) {
        console::log_1(&message.into());
    }

    /// Copies `data` into a new JS `Float32Array`.
    ///
    /// Uses `Float32Array::from(&[f32])`, which performs a single bulk copy
    /// instead of one boundary-crossing `set_index` call per element.
    pub fn create_float32_array(&self, data: &[f32]) -> Result<Float32Array> {
        Ok(Float32Array::from(data))
    }

    /// Copies the contents of a JS `Float32Array` into a `Vec<f32>`.
    ///
    /// Uses `to_vec()`, which performs a single bulk copy instead of one
    /// `get_index` call per element.
    pub fn from_float32_array(&self, array: &Float32Array) -> Result<Vec<f32>> {
        Ok(array.to_vec())
    }

    /// Element-wise addition: `result[i] = a[i] + b[i]`.
    ///
    /// Dispatches to the SIMD128 path when compiled with the `simd128`
    /// target feature, otherwise to a scalar loop.
    ///
    /// # Errors
    /// Returns an invalid-shape error unless `a`, `b`, and `result` all
    /// have the same length.
    pub fn add_simd(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TensorError::invalid_shape_simple(
                "Array lengths must match",
            ));
        }
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            self.add_simd_optimized(a, b, result)
        }
        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
        {
            self.add_scalar(a, b, result)
        }
    }

    // SIMD body for `add_simd`: four f32 lanes per iteration, scalar tail
    // for the remaining `len % 4` elements. Lengths are validated by the
    // public wrapper before this is called.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    fn add_simd_optimized(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        let len = a.len();
        let simd_len = len - (len % 4);
        for i in (0..simd_len).step_by(4) {
            // SAFETY: the caller verified all three slices share `len`, and
            // `i + 3 < simd_len <= len`, so every 16-byte access is in
            // bounds. WASM `v128.load`/`v128.store` tolerate any alignment.
            unsafe {
                let a_vec = v128_load(a.as_ptr().add(i) as *const v128);
                let b_vec = v128_load(b.as_ptr().add(i) as *const v128);
                v128_store(result.as_mut_ptr().add(i) as *mut v128, f32x4_add(a_vec, b_vec));
            }
        }
        for i in simd_len..len {
            result[i] = a[i] + b[i];
        }
        Ok(())
    }

    // Scalar fallback for element-wise addition (also handles non-SIMD
    // builds). Zipped iterators avoid per-element bounds checks.
    fn add_scalar(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        for (out, (&x, &y)) in result.iter_mut().zip(a.iter().zip(b.iter())) {
            *out = x + y;
        }
        Ok(())
    }

    /// Element-wise multiplication: `result[i] = a[i] * b[i]`.
    ///
    /// Dispatches to the SIMD128 path when compiled with the `simd128`
    /// target feature, otherwise to a scalar loop.
    ///
    /// # Errors
    /// Returns an invalid-shape error unless `a`, `b`, and `result` all
    /// have the same length.
    pub fn mul_simd(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TensorError::invalid_shape_simple(
                "Array lengths must match",
            ));
        }
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            self.mul_simd_optimized(a, b, result)
        }
        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
        {
            self.mul_scalar(a, b, result)
        }
    }

    // SIMD body for `mul_simd`; same layout as `add_simd_optimized`.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    fn mul_simd_optimized(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        let len = a.len();
        let simd_len = len - (len % 4);
        for i in (0..simd_len).step_by(4) {
            // SAFETY: lengths validated by the caller; `i + 3 < simd_len
            // <= len` keeps all 16-byte accesses in bounds. WASM v128
            // loads/stores are alignment-tolerant.
            unsafe {
                let a_vec = v128_load(a.as_ptr().add(i) as *const v128);
                let b_vec = v128_load(b.as_ptr().add(i) as *const v128);
                v128_store(result.as_mut_ptr().add(i) as *mut v128, f32x4_mul(a_vec, b_vec));
            }
        }
        for i in simd_len..len {
            result[i] = a[i] * b[i];
        }
        Ok(())
    }

    // Scalar fallback for element-wise multiplication.
    fn mul_scalar(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        for (out, (&x, &y)) in result.iter_mut().zip(a.iter().zip(b.iter())) {
            *out = x * y;
        }
        Ok(())
    }

    /// Element-wise subtraction: `result[i] = a[i] - b[i]`.
    ///
    /// Dispatches to the SIMD128 path when compiled with the `simd128`
    /// target feature, otherwise to a scalar loop.
    ///
    /// # Errors
    /// Returns an invalid-shape error unless `a`, `b`, and `result` all
    /// have the same length.
    pub fn sub_simd(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TensorError::invalid_shape_simple(
                "Array lengths must match",
            ));
        }
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            self.sub_simd_optimized(a, b, result)
        }
        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
        {
            self.sub_scalar(a, b, result)
        }
    }

    // SIMD body for `sub_simd`; same layout as `add_simd_optimized`.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    fn sub_simd_optimized(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        let len = a.len();
        let simd_len = len - (len % 4);
        for i in (0..simd_len).step_by(4) {
            // SAFETY: lengths validated by the caller; `i + 3 < simd_len
            // <= len` keeps all 16-byte accesses in bounds. WASM v128
            // loads/stores are alignment-tolerant.
            unsafe {
                let a_vec = v128_load(a.as_ptr().add(i) as *const v128);
                let b_vec = v128_load(b.as_ptr().add(i) as *const v128);
                v128_store(result.as_mut_ptr().add(i) as *mut v128, f32x4_sub(a_vec, b_vec));
            }
        }
        for i in simd_len..len {
            result[i] = a[i] - b[i];
        }
        Ok(())
    }

    // Scalar fallback for element-wise subtraction.
    fn sub_scalar(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        for (out, (&x, &y)) in result.iter_mut().zip(a.iter().zip(b.iter())) {
            *out = x - y;
        }
        Ok(())
    }

    /// ReLU activation: `result[i] = max(input[i], 0.0)`.
    ///
    /// Dispatches to the SIMD128 path when compiled with the `simd128`
    /// target feature, otherwise to a scalar loop.
    ///
    /// # Errors
    /// Returns an invalid-shape error unless `input` and `result` have the
    /// same length.
    pub fn relu_simd(&self, input: &[f32], result: &mut [f32]) -> Result<()> {
        if input.len() != result.len() {
            return Err(TensorError::invalid_shape_simple(
                "Array lengths must match",
            ));
        }
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            self.relu_simd_optimized(input, result)
        }
        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
        {
            self.relu_scalar(input, result)
        }
    }

    // SIMD body for `relu_simd`: lane-wise max against a zero vector.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    fn relu_simd_optimized(&self, input: &[f32], result: &mut [f32]) -> Result<()> {
        let len = input.len();
        let simd_len = len - (len % 4);
        let zero_vec = f32x4_splat(0.0);
        for i in (0..simd_len).step_by(4) {
            // SAFETY: lengths validated by the caller; `i + 3 < simd_len
            // <= len` keeps all 16-byte accesses in bounds. WASM v128
            // loads/stores are alignment-tolerant.
            unsafe {
                let input_vec = v128_load(input.as_ptr().add(i) as *const v128);
                v128_store(
                    result.as_mut_ptr().add(i) as *mut v128,
                    f32x4_max(input_vec, zero_vec),
                );
            }
        }
        for i in simd_len..len {
            result[i] = input[i].max(0.0);
        }
        Ok(())
    }

    // Scalar fallback for ReLU.
    fn relu_scalar(&self, input: &[f32], result: &mut [f32]) -> Result<()> {
        for (out, &val) in result.iter_mut().zip(input.iter()) {
            *out = val.max(0.0);
        }
        Ok(())
    }

    /// Matrix multiplication of row-major `a` (m x k) by `b` (k x n) into
    /// `result` (m x n).
    ///
    /// Uses the cache-friendly i-l-j loop order over row slices: `b` and
    /// `result` are traversed sequentially and the slicing lets the
    /// compiler elide per-element bounds checks. Results are identical to
    /// the textbook i-j-l order.
    ///
    /// # Errors
    /// Returns an invalid-shape error if any buffer length does not match
    /// its `m`/`n`/`k` dimensions.
    pub fn matmul_wasm(
        &self,
        a: &[f32],
        b: &[f32],
        result: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<()> {
        if a.len() != m * k || b.len() != k * n || result.len() != m * n {
            return Err(TensorError::invalid_shape_simple(
                "Matrix dimensions don't match",
            ));
        }
        for i in 0..m {
            let a_row = &a[i * k..(i + 1) * k];
            let out_row = &mut result[i * n..(i + 1) * n];
            // Accumulate into the output row, so it must start from zero.
            out_row.fill(0.0);
            for (l, &a_val) in a_row.iter().enumerate() {
                let b_row = &b[l * n..(l + 1) * n];
                for (out, &b_val) in out_row.iter_mut().zip(b_row) {
                    *out += a_val * b_val;
                }
            }
        }
        Ok(())
    }
}
/// The default instance is exactly what `WasmTensorOps::new()` produces:
/// no backing memory yet, and a `Performance` handle when the browser
/// provides one.
#[cfg(target_arch = "wasm32")]
impl Default for WasmTensorOps {
    fn default() -> Self {
        WasmTensorOps::new()
    }
}