use crate::cpu::error::CpuResult;
use torsh_core::error::TorshError;
#[cfg(target_arch = "wasm32")]
use core::arch::wasm32::*;
/// CPU compute kernels with optional WebAssembly 128-bit SIMD acceleration.
///
/// On `wasm32` targets with SIMD available the element-wise and reduction
/// kernels process four `f32`/`i32` lanes per instruction; on every other
/// target (or when SIMD is unavailable) they fall back to scalar loops.
#[derive(Debug)]
pub struct WasmSimdOps {
// Whether wasm SIMD was detected when the instance was constructed.
simd_available: bool,
// Advertised vector width in bytes: 16 (one v128) with SIMD, 1 for scalar.
// Informational only — the kernels hard-code 4-lane processing themselves.
vector_width: usize,
}
impl Default for WasmSimdOps {
fn default() -> Self {
Self::new()
}
}
impl WasmSimdOps {
/// Construct the backend, probing wasm SIMD support exactly once.
pub fn new() -> Self {
    let simd_available = Self::detect_wasm_simd();
    // One 16-byte v128 lane group when SIMD is present, plain scalar otherwise.
    let vector_width = if simd_available { 16 } else { 1 };
    Self {
        simd_available,
        vector_width,
    }
}
/// Detect whether WebAssembly 128-bit SIMD can be used.
///
/// The previous implementation called a probe intrinsic inside
/// `std::panic::catch_unwind`, but that cannot work: an unsupported wasm
/// instruction traps at module validation/execution time rather than
/// unwinding, so `catch_unwind` never observes it — and `v128_const` is not
/// an intrinsic exported by `core::arch::wasm32` in the first place. The
/// reliable signal is the compile-time `simd128` target feature: when the
/// module was built with it, every v128 intrinsic used in this file is
/// guaranteed to be available at runtime.
#[cfg(target_arch = "wasm32")]
fn detect_wasm_simd() -> bool {
    cfg!(target_feature = "simd128")
}
// Non-wasm build: wasm SIMD is by definition unavailable, so every kernel
// takes its scalar fallback path.
#[cfg(not(target_arch = "wasm32"))]
fn detect_wasm_simd() -> bool {
false
}
/// Returns `true` when the SIMD code paths will be used.
pub fn is_available(&self) -> bool {
self.simd_available
}
/// Advertised vector width in bytes (16 with SIMD, 1 scalar). Informational.
pub fn vector_width(&self) -> usize {
self.vector_width
}
/// Element-wise `result[i] = a[i] + b[i]`, SIMD-accelerated on wasm32.
///
/// # Errors
/// Returns `ComputeError` when the three slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn add_f32(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> CpuResult<()> {
if !self.simd_available {
return self.add_f32_scalar(a, b, result);
}
if a.len() != b.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
// SAFETY: lengths were validated above; every load/store below stays
// within the slices. wasm v128 loads/stores carry no alignment
// requirement beyond being in-bounds.
unsafe {
let chunks = a.len() / 4; let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vresult = f32x4_add(va, vb);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, vresult);
}
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = a[i] + b[i];
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn add_f32(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> CpuResult<()> {
self.add_f32_scalar(a, b, result)
}
// Portable scalar fallback for element-wise f32 addition.
fn add_f32_scalar(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> CpuResult<()> {
    if a.len() != b.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    // Lengths match, so a lock-step zip visits every element exactly once.
    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x + y;
    }
    Ok(())
}
/// Element-wise `result[i] = a[i] * b[i]`, SIMD-accelerated on wasm32.
///
/// # Errors
/// Returns `ComputeError` when the three slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn mul_f32(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> CpuResult<()> {
if !self.simd_available {
return self.mul_f32_scalar(a, b, result);
}
if a.len() != b.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vresult = f32x4_mul(va, vb);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, vresult);
}
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = a[i] * b[i];
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn mul_f32(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> CpuResult<()> {
self.mul_f32_scalar(a, b, result)
}
// Portable scalar fallback for element-wise f32 multiplication.
fn mul_f32_scalar(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> CpuResult<()> {
    if a.len() != b.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x * y;
    }
    Ok(())
}
/// Dot product of two f32 slices; returns `0.0` for empty inputs.
///
/// NOTE(review): the SIMD path accumulates four partial sums and reduces
/// them at the end, so the floating-point summation order differs from the
/// scalar path — results may differ in the last ulps for the same input.
///
/// # Errors
/// Returns `ComputeError` when the slices differ in length.
pub fn dot_product_f32(&self, a: &[f32], b: &[f32]) -> CpuResult<f32> {
if a.len() != b.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if a.is_empty() {
return Ok(0.0);
}
let mut result = 0.0f32;
// On wasm32 this whole if/else is compiled in; elsewhere it is cfg'd out
// and the block further below provides the scalar loop instead.
#[cfg(target_arch = "wasm32")]
if self.simd_available {
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
let mut sum_vec = f32x4_splat(0.0);
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vmul = f32x4_mul(va, vb);
sum_vec = f32x4_add(sum_vec, vmul);
}
// Horizontal reduction of the four lane accumulators.
let sum_array = [
f32x4_extract_lane::<0>(sum_vec),
f32x4_extract_lane::<1>(sum_vec),
f32x4_extract_lane::<2>(sum_vec),
f32x4_extract_lane::<3>(sum_vec),
];
result = sum_array.iter().sum();
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result += a[i] * b[i];
}
}
} else {
for i in 0..a.len() {
result += a[i] * b[i];
}
}
#[cfg(not(target_arch = "wasm32"))]
{
for i in 0..a.len() {
result += a[i] * b[i];
}
}
Ok(result)
}
/// Naive row-major matrix multiplication: `result = a (m×k) * b (k×n)`.
///
/// # Errors
/// Returns `ComputeError` when any slice length disagrees with the given
/// dimensions.
pub fn matmul_f32(
    &self,
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) -> CpuResult<()> {
    if a.len() != m * k || b.len() != k * n || result.len() != m * n {
        return Err(TorshError::ComputeError(
            "Invalid matrix dimensions".to_string(),
        ));
    }
    for row in 0..m {
        for col in 0..n {
            // Left-to-right fold keeps the accumulation order identical to a
            // plain `+=` loop over the shared dimension.
            let dot: f32 = (0..k).map(|t| a[row * k + t] * b[t * n + col]).sum();
            result[row * n + col] = dot;
        }
    }
    Ok(())
}
/// ReLU activation: `output[i] = max(input[i], 0.0)`, SIMD on wasm32.
///
/// NOTE(review): the SIMD path uses `f32x4_pmax` (pseudo-max) while the
/// scalar paths use `f32::max`; their handling of NaN / signed-zero inputs
/// is not guaranteed identical — confirm if those inputs matter.
///
/// # Errors
/// Returns `ComputeError` when the slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn relu_f32(&self, input: &[f32], output: &mut [f32]) -> CpuResult<()> {
if input.len() != output.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if !self.simd_available {
return self.relu_f32_scalar(input, output);
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = input.len() / 4;
let remainder = input.len() % 4;
let zero_vec = f32x4_splat(0.0);
for i in 0..chunks {
let idx = i * 4;
let input_vec = v128_load(input.as_ptr().add(idx) as *const v128);
let result_vec = f32x4_pmax(input_vec, zero_vec);
v128_store(output.as_mut_ptr().add(idx) as *mut v128, result_vec);
}
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
output[i] = input[i].max(0.0);
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn relu_f32(&self, input: &[f32], output: &mut [f32]) -> CpuResult<()> {
self.relu_f32_scalar(input, output)
}
// Portable scalar ReLU fallback.
fn relu_f32_scalar(&self, input: &[f32], output: &mut [f32]) -> CpuResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for (out, &x) in output.iter_mut().zip(input) {
        *out = x.max(0.0);
    }
    Ok(())
}
/// Numerically-stable softmax over `input`, written into `output`.
///
/// Subtracts the maximum before exponentiating so large inputs do not
/// overflow. An empty input is a no-op.
///
/// # Errors
/// Returns `ComputeError` when the slices differ in length.
pub fn softmax_f32(&self, input: &[f32], output: &mut [f32]) -> CpuResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    if input.is_empty() {
        return Ok(());
    }
    let max_val = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // First pass: shifted exponentials plus their running sum.
    let mut sum = 0.0f32;
    for (out, &x) in output.iter_mut().zip(input) {
        let e = (x - max_val).exp();
        *out = e;
        sum += e;
    }
    // Second pass: normalize so the outputs sum to one.
    for out in output.iter_mut() {
        *out /= sum;
    }
    Ok(())
}
/// Placeholder for memory-layout optimization; currently a deliberate no-op
/// that always succeeds. The parameter is consumed only to keep the
/// intended future signature stable.
pub fn optimize_memory_access<T>(&self, data: &mut [T]) -> CpuResult<()> {
let _ = data; Ok(())
}
/// Fused multiply-add: `result[i] = a[i] * b[i] + c[i]`.
///
/// NOTE(review): the SIMD path performs a separate multiply then add (two
/// roundings) while the scalar paths use `f32::mul_add` (single rounding),
/// so results can differ in the last ulp between paths.
///
/// # Errors
/// Returns `ComputeError` when the four slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn fma_f32(&self, a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) -> CpuResult<()> {
if a.len() != b.len() || a.len() != c.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if !self.simd_available {
return self.fma_f32_scalar(a, b, c, result);
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vc = v128_load(c.as_ptr().add(idx) as *const v128);
let vmul = f32x4_mul(va, vb);
let vresult = f32x4_add(vmul, vc);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, vresult);
}
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = a[i].mul_add(b[i], c[i]);
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn fma_f32(&self, a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) -> CpuResult<()> {
self.fma_f32_scalar(a, b, c, result)
}
// Portable scalar fused multiply-add fallback (single-rounding `mul_add`).
fn fma_f32_scalar(&self, a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) -> CpuResult<()> {
    if a.len() != b.len() || a.len() != c.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for (i, out) in result.iter_mut().enumerate() {
        *out = a[i].mul_add(b[i], c[i]);
    }
    Ok(())
}
/// Report browser/runtime capabilities relevant to deployment.
///
/// NOTE(review): on wasm32 most of the probed values come from hard-coded
/// helper stubs below, not real feature detection — treat them as
/// placeholders until wired to actual JS-side probes.
#[cfg(target_arch = "wasm32")]
pub fn check_browser_compatibility(&self) -> BrowserCompatibility {
BrowserCompatibility {
wasm_simd: self.simd_available,
wasm_threads: Self::detect_wasm_threads(),
shared_array_buffer: Self::detect_shared_array_buffer(),
offscreen_canvas: Self::detect_offscreen_canvas(),
web_workers: Self::detect_web_workers(),
webgl2: Self::detect_webgl2(),
estimated_memory_limit_mb: Self::estimate_memory_limit(),
}
}
// Non-wasm build: all capabilities off (conservative default).
#[cfg(not(target_arch = "wasm32"))]
pub fn check_browser_compatibility(&self) -> BrowserCompatibility {
BrowserCompatibility::default()
}
// Threads require SharedArrayBuffer plus cross-origin isolation; no probe is
// wired up yet, so report unavailable.
#[cfg(target_arch = "wasm32")]
fn detect_wasm_threads() -> bool {
false
}
// Placeholder: always reports false. The catch_unwind wrapper is vestigial —
// the closure cannot panic.
#[cfg(target_arch = "wasm32")]
fn detect_shared_array_buffer() -> bool {
std::panic::catch_unwind(|| {
false
})
.unwrap_or(false)
}
// Hard-coded optimistic stub — TODO: replace with a real JS-side probe.
#[cfg(target_arch = "wasm32")]
fn detect_offscreen_canvas() -> bool {
true }
// Hard-coded optimistic stub — TODO: replace with a real JS-side probe.
#[cfg(target_arch = "wasm32")]
fn detect_web_workers() -> bool {
true
}
// Hard-coded optimistic stub — TODO: replace with a real JS-side probe.
#[cfg(target_arch = "wasm32")]
fn detect_webgl2() -> bool {
true }
// Rough heuristic in MB keyed off pointer width (wasm64 vs wasm32); not an
// actual runtime measurement.
#[cfg(target_arch = "wasm32")]
fn estimate_memory_limit() -> usize {
if cfg!(target_pointer_width = "64") {
2048 } else {
512 }
}
/// Tune the advertised vector width for a memory-constrained browser target.
///
/// `memory_limit_mb` defaults to 512 when absent. The width is only ever
/// clamped downward; limits of 1024 MB and above leave it untouched.
pub fn optimize_for_browser(&mut self, memory_limit_mb: Option<usize>) -> CpuResult<()> {
    match memory_limit_mb.unwrap_or(512) {
        0..=255 => self.vector_width = self.vector_width.min(8),
        256..=1023 => self.vector_width = self.vector_width.min(16),
        _ => {}
    }
    Ok(())
}
/// Blocked (cache-tiled) matrix multiplication: `result = a (m×k) * b (k×n)`.
///
/// `block_size` is the tile edge used for loop blocking. On wasm32 with SIMD
/// available, the inner dot product runs four lanes at a time; the column of
/// `b` is strided and is therefore gathered into a temporary lane array.
///
/// # Errors
/// Returns `ComputeError` when the slice lengths disagree with the given
/// dimensions, or when `block_size` is zero — previously a zero block size
/// reached `Iterator::step_by(0)`, which panics.
#[allow(clippy::too_many_arguments)]
pub fn matmul_f32_blocked(
    &self,
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
    block_size: usize,
) -> CpuResult<()> {
    if a.len() != m * k || b.len() != k * n || result.len() != m * n {
        return Err(TorshError::ComputeError(
            "Invalid matrix dimensions".to_string(),
        ));
    }
    // `step_by` panics on a step of zero; reject it up front as an error
    // rather than a panic.
    if block_size == 0 {
        return Err(TorshError::ComputeError(
            "block_size must be non-zero".to_string(),
        ));
    }
    // Results accumulate across k-tiles, so the output must start zeroed.
    result.fill(0.0);
    for i0 in (0..m).step_by(block_size) {
        for j0 in (0..n).step_by(block_size) {
            for k0 in (0..k).step_by(block_size) {
                let i_end = (i0 + block_size).min(m);
                let j_end = (j0 + block_size).min(n);
                let k_end = (k0 + block_size).min(k);
                for i in i0..i_end {
                    for j in j0..j_end {
                        let mut sum = 0.0f32;
                        #[cfg(target_arch = "wasm32")]
                        if self.simd_available && (k_end - k0) >= 4 {
                            // SAFETY: i < m and l + 3 < k_end <= k, so every
                            // v128 load reads four in-bounds elements of `a`.
                            unsafe {
                                let simd_end = k0 + ((k_end - k0) / 4) * 4;
                                let mut sum_vec = f32x4_splat(0.0);
                                for l in (k0..simd_end).step_by(4) {
                                    let va =
                                        v128_load(a.as_ptr().add(i * k + l) as *const v128);
                                    // Column of `b` is strided — gather it.
                                    let vb_vals = [
                                        b[l * n + j],
                                        b[(l + 1) * n + j],
                                        b[(l + 2) * n + j],
                                        b[(l + 3) * n + j],
                                    ];
                                    let vb =
                                        f32x4(vb_vals[0], vb_vals[1], vb_vals[2], vb_vals[3]);
                                    let vmul = f32x4_mul(va, vb);
                                    sum_vec = f32x4_add(sum_vec, vmul);
                                }
                                // Horizontal reduction of the lane partials.
                                sum += f32x4_extract_lane::<0>(sum_vec)
                                    + f32x4_extract_lane::<1>(sum_vec)
                                    + f32x4_extract_lane::<2>(sum_vec)
                                    + f32x4_extract_lane::<3>(sum_vec);
                                // Scalar tail of this k-tile.
                                for l in simd_end..k_end {
                                    sum += a[i * k + l] * b[l * n + j];
                                }
                            }
                        } else {
                            for l in k0..k_end {
                                sum += a[i * k + l] * b[l * n + j];
                            }
                        }
                        #[cfg(not(target_arch = "wasm32"))]
                        {
                            for l in k0..k_end {
                                sum += a[i * k + l] * b[l * n + j];
                            }
                        }
                        result[i * n + j] += sum;
                    }
                }
            }
        }
    }
    Ok(())
}
/// Sum of all elements; returns `0.0` for an empty slice.
///
/// NOTE(review): the SIMD path reduces four lane partials at the end, so
/// the floating-point summation order differs from the scalar `iter().sum()`
/// path and results may differ in the last ulps.
#[cfg(target_arch = "wasm32")]
pub fn sum_f32(&self, input: &[f32]) -> CpuResult<f32> {
if input.is_empty() {
return Ok(0.0);
}
if !self.simd_available {
return Ok(input.iter().sum());
}
// SAFETY: `chunks * 4 <= input.len()`, so every load is in-bounds.
unsafe {
let chunks = input.len() / 4;
let remainder = input.len() % 4;
let mut sum_vec = f32x4_splat(0.0);
for i in 0..chunks {
let idx = i * 4;
let input_vec = v128_load(input.as_ptr().add(idx) as *const v128);
sum_vec = f32x4_add(sum_vec, input_vec);
}
// Horizontal reduction of the four lane accumulators.
let mut result = f32x4_extract_lane::<0>(sum_vec)
+ f32x4_extract_lane::<1>(sum_vec)
+ f32x4_extract_lane::<2>(sum_vec)
+ f32x4_extract_lane::<3>(sum_vec);
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result += input[i];
}
Ok(result)
}
}
// Non-wasm build: plain iterator sum.
#[cfg(not(target_arch = "wasm32"))]
pub fn sum_f32(&self, input: &[f32]) -> CpuResult<f32> {
Ok(input.iter().sum())
}
/// Element-wise `result[i] = a[i] + b[i]` for i32, SIMD on wasm32.
///
/// NOTE(review): the SIMD path wraps on overflow while the scalar path uses
/// `+`, which panics on overflow in debug builds — the two paths diverge for
/// overflowing inputs; confirm intended semantics.
///
/// # Errors
/// Returns `ComputeError` when the three slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn add_i32(&self, a: &[i32], b: &[i32], result: &mut [i32]) -> CpuResult<()> {
if a.len() != b.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if !self.simd_available {
return self.add_i32_scalar(a, b, result);
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vresult = i32x4_add(va, vb);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, vresult);
}
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = a[i] + b[i];
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn add_i32(&self, a: &[i32], b: &[i32], result: &mut [i32]) -> CpuResult<()> {
self.add_i32_scalar(a, b, result)
}
// Portable scalar fallback for element-wise i32 addition.
fn add_i32_scalar(&self, a: &[i32], b: &[i32], result: &mut [i32]) -> CpuResult<()> {
    if a.len() != b.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x + y;
    }
    Ok(())
}
/// Element-wise `result[i] = a[i] * b[i]` for i32, SIMD on wasm32.
///
/// NOTE(review): as with `add_i32`, SIMD wraps on overflow while the scalar
/// path panics in debug builds — confirm intended overflow semantics.
///
/// # Errors
/// Returns `ComputeError` when the three slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn mul_i32(&self, a: &[i32], b: &[i32], result: &mut [i32]) -> CpuResult<()> {
if a.len() != b.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if !self.simd_available {
return self.mul_i32_scalar(a, b, result);
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vresult = i32x4_mul(va, vb);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, vresult);
}
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = a[i] * b[i];
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn mul_i32(&self, a: &[i32], b: &[i32], result: &mut [i32]) -> CpuResult<()> {
self.mul_i32_scalar(a, b, result)
}
// Portable scalar fallback for element-wise i32 multiplication.
fn mul_i32_scalar(&self, a: &[i32], b: &[i32], result: &mut [i32]) -> CpuResult<()> {
    if a.len() != b.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x * y;
    }
    Ok(())
}
/// Maximum element of a non-empty slice.
///
/// NOTE(review): the SIMD path uses `f32x4_pmax` while the scalar path uses
/// `f32::max` (which ignores NaN) — NaN handling may differ between paths.
///
/// # Errors
/// Returns `ComputeError` for an empty input.
#[cfg(target_arch = "wasm32")]
pub fn max_f32(&self, input: &[f32]) -> CpuResult<f32> {
if input.is_empty() {
return Err(TorshError::ComputeError(
"Cannot find max of empty array".to_string(),
));
}
if !self.simd_available {
return Ok(input.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)));
}
// SAFETY: emptiness checked above; every load offset is in-bounds.
unsafe {
let chunks = input.len() / 4;
let remainder = input.len() % 4;
// Seed with the first vector (or a splat of input[0] when the slice is
// shorter than one vector); elements seen twice are harmless for max.
let mut max_vec = if chunks > 0 {
v128_load(input.as_ptr() as *const v128)
} else {
f32x4_splat(input[0])
};
for i in 1..chunks {
let idx = i * 4;
let input_vec = v128_load(input.as_ptr().add(idx) as *const v128);
max_vec = f32x4_pmax(max_vec, input_vec);
}
// Horizontal reduction across the four lanes.
let mut result = f32x4_extract_lane::<0>(max_vec)
.max(f32x4_extract_lane::<1>(max_vec))
.max(f32x4_extract_lane::<2>(max_vec))
.max(f32x4_extract_lane::<3>(max_vec));
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result = result.max(input[i]);
}
Ok(result)
}
}
// Non-wasm build: scalar fold.
#[cfg(not(target_arch = "wasm32"))]
pub fn max_f32(&self, input: &[f32]) -> CpuResult<f32> {
if input.is_empty() {
return Err(TorshError::ComputeError(
"Cannot find max of empty array".to_string(),
));
}
Ok(input.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)))
}
/// Minimum element of a non-empty slice (mirror of `max_f32`).
///
/// NOTE(review): same NaN caveat as `max_f32` — `f32x4_pmin` vs `f32::min`.
///
/// # Errors
/// Returns `ComputeError` for an empty input.
#[cfg(target_arch = "wasm32")]
pub fn min_f32(&self, input: &[f32]) -> CpuResult<f32> {
if input.is_empty() {
return Err(TorshError::ComputeError(
"Cannot find min of empty array".to_string(),
));
}
if !self.simd_available {
return Ok(input.iter().fold(f32::INFINITY, |a, &b| a.min(b)));
}
// SAFETY: emptiness checked above; every load offset is in-bounds.
unsafe {
let chunks = input.len() / 4;
let remainder = input.len() % 4;
// Seed with the first vector (or a splat of input[0] for short slices);
// elements visited twice are harmless for min.
let mut min_vec = if chunks > 0 {
v128_load(input.as_ptr() as *const v128)
} else {
f32x4_splat(input[0])
};
for i in 1..chunks {
let idx = i * 4;
let input_vec = v128_load(input.as_ptr().add(idx) as *const v128);
min_vec = f32x4_pmin(min_vec, input_vec);
}
// Horizontal reduction across the four lanes.
let mut result = f32x4_extract_lane::<0>(min_vec)
.min(f32x4_extract_lane::<1>(min_vec))
.min(f32x4_extract_lane::<2>(min_vec))
.min(f32x4_extract_lane::<3>(min_vec));
// Scalar tail for the final len % 4 elements.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result = result.min(input[i]);
}
Ok(result)
}
}
// Non-wasm build: scalar fold.
#[cfg(not(target_arch = "wasm32"))]
pub fn min_f32(&self, input: &[f32]) -> CpuResult<f32> {
if input.is_empty() {
return Err(TorshError::ComputeError(
"Cannot find min of empty array".to_string(),
));
}
Ok(input.iter().fold(f32::INFINITY, |a, &b| a.min(b)))
}
/// Element-wise comparison: `result[i] = u32::MAX` when `a[i] > b[i]`,
/// else `0` (canonical all-ones/all-zero mask format, matching the wasm
/// `f32x4_gt` lane encoding).
///
/// # Errors
/// Returns `ComputeError` when the three slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn greater_than_f32(&self, a: &[f32], b: &[f32], result: &mut [u32]) -> CpuResult<()> {
if a.len() != b.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if !self.simd_available {
return self.greater_than_f32_scalar(a, b, result);
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let mask = f32x4_gt(va, vb);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, mask);
}
// Scalar tail, producing the same all-ones/all-zero encoding.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = if a[i] > b[i] { u32::MAX } else { 0 };
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn greater_than_f32(&self, a: &[f32], b: &[f32], result: &mut [u32]) -> CpuResult<()> {
self.greater_than_f32_scalar(a, b, result)
}
// Portable scalar comparison fallback; emits canonical all-ones/zero masks.
fn greater_than_f32_scalar(&self, a: &[f32], b: &[f32], result: &mut [u32]) -> CpuResult<()> {
    if a.len() != b.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for (i, flag) in result.iter_mut().enumerate() {
        *flag = if a[i] > b[i] { u32::MAX } else { 0 };
    }
    Ok(())
}
/// Lane select: `result[i] = a[i]` where the mask is set, else `b[i]`.
///
/// NOTE(review): the SIMD path uses `v128_bitselect`, which is *bitwise* —
/// only canonical masks (all-ones / all-zero per lane, e.g. the output of
/// `greater_than_f32`) match the scalar `mask[i] != 0` semantics. Partial
/// masks would blend bits from both operands.
///
/// # Errors
/// Returns `ComputeError` when the four slices differ in length.
#[cfg(target_arch = "wasm32")]
pub fn select_f32(
&self,
mask: &[u32],
a: &[f32],
b: &[f32],
result: &mut [f32],
) -> CpuResult<()> {
if mask.len() != a.len() || a.len() != b.len() || a.len() != result.len() {
return Err(TorshError::ComputeError(
"Array length mismatch".to_string(),
));
}
if !self.simd_available {
return self.select_f32_scalar(mask, a, b, result);
}
// SAFETY: lengths validated above; all pointer offsets stay in-bounds.
unsafe {
let chunks = a.len() / 4;
let remainder = a.len() % 4;
for i in 0..chunks {
let idx = i * 4;
let vmask = v128_load(mask.as_ptr().add(idx) as *const v128);
let va = v128_load(a.as_ptr().add(idx) as *const v128);
let vb = v128_load(b.as_ptr().add(idx) as *const v128);
let vresult = v128_bitselect(va, vb, vmask);
v128_store(result.as_mut_ptr().add(idx) as *mut v128, vresult);
}
// Scalar tail: any non-zero mask selects `a`.
for i in (chunks * 4)..(chunks * 4 + remainder) {
result[i] = if mask[i] != 0 { a[i] } else { b[i] };
}
}
Ok(())
}
// Non-wasm build: always the scalar path.
#[cfg(not(target_arch = "wasm32"))]
pub fn select_f32(
&self,
mask: &[u32],
a: &[f32],
b: &[f32],
result: &mut [f32],
) -> CpuResult<()> {
self.select_f32_scalar(mask, a, b, result)
}
// Portable scalar select fallback: non-zero mask picks `a`, zero picks `b`.
fn select_f32_scalar(
    &self,
    mask: &[u32],
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
) -> CpuResult<()> {
    if mask.len() != a.len() || a.len() != b.len() || a.len() != result.len() {
        return Err(TorshError::ComputeError(
            "Array length mismatch".to_string(),
        ));
    }
    for (((out, &m), &x), &y) in result.iter_mut().zip(mask).zip(a).zip(b) {
        *out = if m != 0 { x } else { y };
    }
    Ok(())
}
/// Summarize this backend's capabilities for diagnostics and planning.
///
/// NOTE(review): `estimated_speedup` and `memory_bandwidth_gbps` are rough
/// hard-coded heuristics, not measurements.
pub fn get_performance_info(&self) -> WasmSimdPerformanceInfo {
WasmSimdPerformanceInfo {
simd_available: self.simd_available,
vector_width: self.vector_width,
estimated_speedup: if self.simd_available { 3.0 } else { 1.0 }, supports_f32: true,
supports_f64: self.simd_available && cfg!(target_feature = "simd128"), supports_i32: self.simd_available,
memory_bandwidth_gbps: if self.simd_available { 15.0 } else { 8.0 },
browser_compatibility: self.check_browser_compatibility(),
}
}
}
/// Capability report for the hosting browser/runtime.
#[derive(Debug, Clone)]
pub struct BrowserCompatibility {
    /// WebAssembly 128-bit SIMD support.
    pub wasm_simd: bool,
    /// WebAssembly threads support.
    pub wasm_threads: bool,
    /// `SharedArrayBuffer` availability.
    pub shared_array_buffer: bool,
    /// `OffscreenCanvas` availability.
    pub offscreen_canvas: bool,
    /// Web Worker availability.
    pub web_workers: bool,
    /// WebGL 2 availability.
    pub webgl2: bool,
    /// Rough memory budget estimate in megabytes.
    pub estimated_memory_limit_mb: usize,
}

impl Default for BrowserCompatibility {
    // Conservative baseline: every capability off, 512 MB memory budget.
    fn default() -> Self {
        BrowserCompatibility {
            estimated_memory_limit_mb: 512,
            wasm_simd: false,
            wasm_threads: false,
            shared_array_buffer: false,
            offscreen_canvas: false,
            web_workers: false,
            webgl2: false,
        }
    }
}
/// Snapshot of the backend's capabilities and rough performance heuristics,
/// as produced by `WasmSimdOps::get_performance_info`.
#[derive(Debug, Clone)]
pub struct WasmSimdPerformanceInfo {
// Whether SIMD code paths are active.
pub simd_available: bool,
// Advertised vector width in bytes (16 SIMD / 1 scalar).
pub vector_width: usize,
// Heuristic speedup factor vs scalar; not a measurement.
pub estimated_speedup: f32,
pub supports_f32: bool,
pub supports_f64: bool,
pub supports_i32: bool,
// Heuristic bandwidth estimate; not a measurement.
pub memory_bandwidth_gbps: f32,
pub browser_compatibility: BrowserCompatibility,
}
/// Tuning knobs for deploying the wasm backend in different environments.
#[derive(Debug, Clone)]
pub struct WasmDeploymentConfig {
    /// Memory budget in megabytes.
    pub memory_limit_mb: usize,
    /// Enable optimizations that trade memory/compat for speed.
    pub aggressive_optimizations: bool,
    /// Offload work to Web Workers.
    pub use_web_workers: bool,
    /// Prefer memory-frugal code paths.
    pub memory_efficient: bool,
    /// Tile edge for blocked matrix multiplication.
    pub matrix_block_size: usize,
    /// Enable extra diagnostics.
    pub debug_mode: bool,
}

impl Default for WasmDeploymentConfig {
    // Balanced baseline: 512 MB budget, conservative settings, 64-wide tiles.
    fn default() -> Self {
        Self {
            memory_limit_mb: 512,
            aggressive_optimizations: false,
            use_web_workers: false,
            memory_efficient: true,
            matrix_block_size: 64,
            debug_mode: false,
        }
    }
}

impl WasmDeploymentConfig {
    /// Preset for memory-constrained mobile browsers: smaller budget and
    /// tiles, otherwise the conservative defaults.
    pub fn mobile_optimized() -> Self {
        Self {
            memory_limit_mb: 256,
            matrix_block_size: 32,
            ..Self::default()
        }
    }
    /// Preset for desktop browsers: large budget, aggressive optimizations,
    /// Web Workers on, speed favored over memory.
    pub fn desktop_optimized() -> Self {
        Self {
            memory_limit_mb: 2048,
            aggressive_optimizations: true,
            use_web_workers: true,
            memory_efficient: false,
            matrix_block_size: 128,
            ..Self::default()
        }
    }
    /// Preset for debugging: defaults plus a 1 GB budget and diagnostics on.
    pub fn debug() -> Self {
        Self {
            memory_limit_mb: 1024,
            debug_mode: true,
            ..Self::default()
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE: on non-wasm targets every operation exercises the scalar fallback;
// the SIMD code paths only run (and are only asserted) on wasm32.
#[test]
fn test_wasm_simd_creation() {
let wasm_ops = WasmSimdOps::new();
let info = wasm_ops.get_performance_info();
#[cfg(target_arch = "wasm32")]
assert!(info.simd_available);
#[cfg(not(target_arch = "wasm32"))]
assert!(!info.simd_available);
assert!(info.supports_f32);
assert!(info.estimated_speedup >= 1.0);
}
// Length 5 deliberately exercises one SIMD chunk plus a scalar tail.
#[test]
fn test_add_f32() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![2.0, 3.0, 4.0, 5.0, 6.0];
let mut result = vec![0.0; 5];
let res = wasm_ops.add_f32(&a, &b, &mut result);
assert!(res.is_ok());
let expected = vec![3.0, 5.0, 7.0, 9.0, 11.0];
assert_eq!(result, expected);
}
#[test]
fn test_mul_f32() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![2.0, 3.0, 4.0, 5.0];
let mut result = vec![0.0; 4];
let res = wasm_ops.mul_f32(&a, &b, &mut result);
assert!(res.is_ok());
let expected = vec![2.0, 6.0, 12.0, 20.0];
assert_eq!(result, expected);
}
#[test]
fn test_dot_product() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![2.0, 3.0, 4.0, 5.0];
let result = wasm_ops
.dot_product_f32(&a, &b)
.expect("dot product computation should succeed");
let expected = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0; assert!((result - expected).abs() < 1e-6);
}
#[test]
fn test_relu() {
let wasm_ops = WasmSimdOps::new();
let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
let mut output = vec![0.0; 5];
let res = wasm_ops.relu_f32(&input, &mut output);
assert!(res.is_ok());
let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];
assert_eq!(output, expected);
}
// Softmax invariants: outputs in [0, 1], sum to 1, and order-preserving.
#[test]
fn test_softmax() {
let wasm_ops = WasmSimdOps::new();
let input = vec![1.0, 2.0, 3.0];
let mut output = vec![0.0; 3];
let res = wasm_ops.softmax_f32(&input, &mut output);
assert!(res.is_ok());
let sum: f32 = output.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
for &val in &output {
assert!(val >= 0.0 && val <= 1.0);
}
assert!(output[0] < output[1]);
assert!(output[1] < output[2]);
}
// 2x3 times 3x2 with known result.
#[test]
fn test_matmul() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let mut result = vec![0.0; 4];
let res = wasm_ops.matmul_f32(&a, &b, &mut result, 2, 2, 3);
assert!(res.is_ok());
let expected = vec![22.0, 28.0, 49.0, 64.0];
assert_eq!(result, expected);
}
#[test]
fn test_array_length_mismatch() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0]; let mut result = vec![0.0; 2];
let res = wasm_ops.add_f32(&a, &b, &mut result);
assert!(res.is_err());
}
#[test]
fn test_performance_info() {
let wasm_ops = WasmSimdOps::new();
let info = wasm_ops.get_performance_info();
assert!(info.vector_width > 0);
assert!(info.estimated_speedup >= 1.0);
assert!(info.memory_bandwidth_gbps > 0.0);
}
#[test]
fn test_sum_f32() {
let wasm_ops = WasmSimdOps::new();
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let result = wasm_ops.sum_f32(&input).expect("f32 sum should succeed");
assert_eq!(result, 15.0);
let empty: Vec<f32> = vec![];
let empty_result = wasm_ops.sum_f32(&empty).expect("f32 sum should succeed");
assert_eq!(empty_result, 0.0);
}
#[test]
fn test_add_i32() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1, 2, 3, 4, 5];
let b = vec![2, 3, 4, 5, 6];
let mut result = vec![0; 5];
let res = wasm_ops.add_i32(&a, &b, &mut result);
assert!(res.is_ok());
let expected = vec![3, 5, 7, 9, 11];
assert_eq!(result, expected);
}
#[test]
fn test_mul_i32() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1, 2, 3, 4];
let b = vec![2, 3, 4, 5];
let mut result = vec![0; 4];
let res = wasm_ops.mul_i32(&a, &b, &mut result);
assert!(res.is_ok());
let expected = vec![2, 6, 12, 20];
assert_eq!(result, expected);
}
// Covers multi-element, single-element, and empty (error) inputs.
#[test]
fn test_max_f32() {
let wasm_ops = WasmSimdOps::new();
let input = vec![1.0, 5.0, 2.0, 8.0, 3.0];
let result = wasm_ops.max_f32(&input).expect("f32 max should succeed");
assert_eq!(result, 8.0);
let single = vec![42.0];
let single_result = wasm_ops.max_f32(&single).expect("f32 max should succeed");
assert_eq!(single_result, 42.0);
let empty: Vec<f32> = vec![];
let empty_result = wasm_ops.max_f32(&empty);
assert!(empty_result.is_err());
}
#[test]
fn test_min_f32() {
let wasm_ops = WasmSimdOps::new();
let input = vec![5.0, 1.0, 8.0, 2.0, 3.0];
let result = wasm_ops.min_f32(&input).expect("f32 min should succeed");
assert_eq!(result, 1.0);
let input_neg = vec![-5.0, -1.0, -8.0, -2.0];
let result_neg = wasm_ops
.min_f32(&input_neg)
.expect("f32 min should succeed");
assert_eq!(result_neg, -8.0);
}
#[test]
fn test_greater_than_f32() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 3.0, 5.0, 2.0];
let b = vec![2.0, 2.0, 4.0, 3.0];
let mut result = vec![0u32; 4];
let res = wasm_ops.greater_than_f32(&a, &b, &mut result);
assert!(res.is_ok());
let expected = vec![0, u32::MAX, u32::MAX, 0];
assert_eq!(result, expected);
}
// Uses canonical all-ones/all-zero masks, matching greater_than_f32 output.
#[test]
fn test_select_f32() {
let wasm_ops = WasmSimdOps::new();
let mask = vec![u32::MAX, 0, u32::MAX, 0]; let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![10.0, 20.0, 30.0, 40.0];
let mut result = vec![0.0; 4];
let res = wasm_ops.select_f32(&mask, &a, &b, &mut result);
assert!(res.is_ok());
let expected = vec![1.0, 20.0, 3.0, 40.0];
assert_eq!(result, expected);
}
#[test]
fn test_fma_f32() {
let wasm_ops = WasmSimdOps::new();
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![2.0, 3.0, 4.0, 5.0];
let c = vec![1.0, 1.0, 1.0, 1.0];
let mut result = vec![0.0; 4];
let res = wasm_ops.fma_f32(&a, &b, &c, &mut result);
assert!(res.is_ok());
let expected = vec![3.0, 7.0, 13.0, 21.0];
assert_eq!(result, expected);
}
// Large input exercises many full SIMD chunks on wasm32.
#[test]
fn test_large_arrays_simd() {
let wasm_ops = WasmSimdOps::new();
let size = 1000;
let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..size).map(|i| (i + 1) as f32).collect();
let mut result = vec![0.0; size];
let res = wasm_ops.add_f32(&a, &b, &mut result);
assert!(res.is_ok());
assert_eq!(result[0], 1.0); assert_eq!(result[1], 3.0); assert_eq!(result[size - 1], (2 * size - 1) as f32);
let sum_result = wasm_ops.sum_f32(&a).expect("f32 sum should succeed");
let expected: f32 = (0..size).map(|i| i as f32).sum();
assert!((sum_result - expected).abs() < 1e-6);
}
#[test]
fn test_deployment_configs() {
let mobile_config = WasmDeploymentConfig::mobile_optimized();
assert_eq!(mobile_config.memory_limit_mb, 256);
assert_eq!(mobile_config.matrix_block_size, 32);
assert!(!mobile_config.aggressive_optimizations);
let desktop_config = WasmDeploymentConfig::desktop_optimized();
assert_eq!(desktop_config.memory_limit_mb, 2048);
assert_eq!(desktop_config.matrix_block_size, 128);
assert!(desktop_config.aggressive_optimizations);
let debug_config = WasmDeploymentConfig::debug();
assert!(debug_config.debug_mode);
assert!(!debug_config.aggressive_optimizations);
}
#[test]
fn test_browser_compatibility() {
let wasm_ops = WasmSimdOps::new();
let compatibility = wasm_ops.check_browser_compatibility();
assert!(compatibility.estimated_memory_limit_mb > 0);
#[cfg(not(target_arch = "wasm32"))]
{
assert!(!compatibility.wasm_simd);
assert!(!compatibility.wasm_threads);
}
}
// Multiplying by the identity must return the input unchanged.
#[test]
fn test_blocked_matmul() {
let wasm_ops = WasmSimdOps::new();
let a = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
];
let b = vec![
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
]; let mut result = vec![0.0; 16];
let res = wasm_ops.matmul_f32_blocked(&a, &b, &mut result, 4, 4, 4, 2);
assert!(res.is_ok());
assert_eq!(result, a);
}
}