use std::collections::HashMap;
use anyhow::{Result, anyhow};
use ronn_core::{CompiledKernel, DataType, KernelStats, MemoryUsage, Tensor, TensorLayout};
/// Vectorised `f32` primitives used by [`WasmKernel`].
///
/// When compiled for `wasm32` with the `simd128` target feature enabled, each
/// operation processes four lanes at a time with `core::arch::wasm32`
/// intrinsics and finishes the remaining `len % 4` elements with scalar code.
/// On every other build — including `wasm32` without `simd128` — a scalar
/// fallback with the same contract is used.
#[derive(Debug, Clone)]
pub struct WasmSimd128Ops;

impl WasmSimd128Ops {
    /// Returns `true` when compiled with the `simd128` target feature, or when
    /// the target is not `wasm32` at all.
    pub fn is_simd_available() -> bool {
        cfg!(target_feature = "simd128") || cfg!(not(target_arch = "wasm32"))
    }

    /// Element-wise addition: `result[i] = a[i] + b[i]`.
    ///
    /// # Errors
    /// Returns an error if `a`, `b` and `result` do not all have the same length.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    pub fn simd_add_f32(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        use core::arch::wasm32::*;
        if a.len() != b.len() || a.len() != result.len() {
            return Err(anyhow!("Array length mismatch in SIMD add"));
        }
        let simd_len = a.len() & !3;
        for i in (0..simd_len).step_by(4) {
            // SAFETY: the three slices have equal length and `i + 4 <= simd_len`,
            // so every 16-byte access is in bounds; `v128_load`/`v128_store`
            // emit unaligned (align = 1) memory operations.
            unsafe {
                let va = v128_load(a.as_ptr().add(i) as *const v128);
                let vb = v128_load(b.as_ptr().add(i) as *const v128);
                v128_store(result.as_mut_ptr().add(i) as *mut v128, f32x4_add(va, vb));
            }
        }
        for i in simd_len..a.len() {
            result[i] = a[i] + b[i];
        }
        Ok(())
    }

    /// Element-wise addition: `result[i] = a[i] + b[i]` (scalar fallback).
    ///
    /// # Errors
    /// Returns an error if `a`, `b` and `result` do not all have the same length.
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    pub fn simd_add_f32(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(anyhow!("Array length mismatch in add"));
        }
        for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *out = x + y;
        }
        Ok(())
    }

    /// Element-wise multiplication: `result[i] = a[i] * b[i]`.
    ///
    /// # Errors
    /// Returns an error if `a`, `b` and `result` do not all have the same length.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    pub fn simd_mul_f32(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        use core::arch::wasm32::*;
        if a.len() != b.len() || a.len() != result.len() {
            return Err(anyhow!("Array length mismatch in SIMD mul"));
        }
        let simd_len = a.len() & !3;
        for i in (0..simd_len).step_by(4) {
            // SAFETY: same bounds argument as `simd_add_f32`.
            unsafe {
                let va = v128_load(a.as_ptr().add(i) as *const v128);
                let vb = v128_load(b.as_ptr().add(i) as *const v128);
                v128_store(result.as_mut_ptr().add(i) as *mut v128, f32x4_mul(va, vb));
            }
        }
        for i in simd_len..a.len() {
            result[i] = a[i] * b[i];
        }
        Ok(())
    }

    /// Element-wise multiplication: `result[i] = a[i] * b[i]` (scalar fallback).
    ///
    /// # Errors
    /// Returns an error if `a`, `b` and `result` do not all have the same length
    /// (previously this path panicked or silently left `result` partly unwritten).
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    pub fn simd_mul_f32(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(anyhow!("Array length mismatch in mul"));
        }
        for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *out = x * y;
        }
        Ok(())
    }

    /// ReLU: `output[i] = max(input[i], 0.0)`; NaN inputs map to `0.0`.
    ///
    /// # Errors
    /// Returns an error if `input` and `output` have different lengths.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) -> Result<()> {
        use core::arch::wasm32::*;
        if input.len() != output.len() {
            return Err(anyhow!("Array length mismatch in SIMD ReLU"));
        }
        let simd_len = input.len() & !3;
        let zeros = unsafe { f32x4_splat(0.0) };
        for i in (0..simd_len).step_by(4) {
            // SAFETY: both slices have equal length and `i + 4 <= simd_len`.
            unsafe {
                let v = v128_load(input.as_ptr().add(i) as *const v128);
                // `pmax(zeros, v)` is `0 < v ? v : 0`, so NaN lanes become 0.0,
                // matching the scalar `f32::max(x, 0.0)` used for the tail.
                v128_store(output.as_mut_ptr().add(i) as *mut v128, f32x4_pmax(zeros, v));
            }
        }
        for i in simd_len..input.len() {
            output[i] = input[i].max(0.0);
        }
        Ok(())
    }

    /// ReLU: `output[i] = max(input[i], 0.0)` (scalar fallback).
    ///
    /// # Errors
    /// Returns an error if `input` and `output` have different lengths.
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) -> Result<()> {
        if input.len() != output.len() {
            return Err(anyhow!("Array length mismatch in ReLU"));
        }
        for (out, &x) in output.iter_mut().zip(input) {
            *out = x.max(0.0);
        }
        Ok(())
    }

    /// Row-major matrix product `c = a * b`, with `a` of shape `m x k`, `b` of
    /// shape `k x n` and `c` of shape `m x n`.
    ///
    /// # Errors
    /// Returns an error if any slice length does not match its dimensions.
    pub fn simd_matmul_f32(
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<()> {
        // `checked_mul` keeps huge dimensions from wrapping around and
        // spuriously matching the slice lengths.
        let dims_match = m.checked_mul(k) == Some(a.len())
            && k.checked_mul(n) == Some(b.len())
            && m.checked_mul(n) == Some(c.len());
        if !dims_match {
            return Err(anyhow!("Matrix dimension mismatch"));
        }
        for i in 0..m {
            let a_row = &a[i * k..(i + 1) * k];
            for j in 0..n {
                c[i * n + j] = Self::dot_row_col(a_row, b, n, j);
            }
        }
        Ok(())
    }

    /// Dot product of `a_row` (length `k`) with column `j` of the row-major
    /// `k x n` matrix `b`, accumulating four products per step.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    fn dot_row_col(a_row: &[f32], b: &[f32], n: usize, j: usize) -> f32 {
        use core::arch::wasm32::*;
        let k = a_row.len();
        let simd_k = k & !3;
        let mut acc = unsafe { f32x4_splat(0.0) };
        for l in (0..simd_k).step_by(4) {
            // Consecutive column elements are `n` apart in memory, so they are
            // gathered lane by lane (bounds-checked) rather than loaded as one
            // contiguous vector, which would read along a row instead.
            let col = [
                b[l * n + j],
                b[(l + 1) * n + j],
                b[(l + 2) * n + j],
                b[(l + 3) * n + j],
            ];
            // SAFETY: `l + 4 <= simd_k <= a_row.len()`, so the 16-byte load is in
            // bounds; `v128_load` performs an unaligned load.
            unsafe {
                let va = v128_load(a_row.as_ptr().add(l) as *const v128);
                let vb = f32x4(col[0], col[1], col[2], col[3]);
                acc = f32x4_add(acc, f32x4_mul(va, vb));
            }
        }
        let mut sum = unsafe {
            f32x4_extract_lane::<0>(acc)
                + f32x4_extract_lane::<1>(acc)
                + f32x4_extract_lane::<2>(acc)
                + f32x4_extract_lane::<3>(acc)
        };
        for l in simd_k..k {
            sum += a_row[l] * b[l * n + j];
        }
        sum
    }

    /// Dot product of `a_row` (length `k`) with column `j` of the row-major
    /// `k x n` matrix `b` (scalar fallback covering all `k` terms).
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    fn dot_row_col(a_row: &[f32], b: &[f32], n: usize, j: usize) -> f32 {
        a_row
            .iter()
            .enumerate()
            .fold(0.0f32, |sum, (l, &x)| sum + x * b[l * n + j])
    }
}
/// A [`CompiledKernel`] that runs a single operator on the CPU, delegating the
/// numeric work to [`WasmSimd128Ops`] where a SIMD kernel exists.
///
/// The implementation is chosen once from the operator name in
/// [`WasmKernel::new`]; unknown operators fail when executed.
#[derive(Debug)]
pub struct WasmKernel {
    /// Operator name this kernel was built for (used in error messages).
    op_type: String,
    /// Implementation selected from `op_type`.
    kernel_fn: fn(&WasmKernel, &[Tensor]) -> Result<Vec<Tensor>>,
    /// Timing statistics. Behind a mutex because `CompiledKernel::execute`
    /// only receives `&self` but must still record each run.
    stats: std::sync::Mutex<KernelStats>,
    /// Reported memory usage; nothing in this kernel updates it.
    memory_usage: MemoryUsage,
    /// Per-kernel options; currently not read by any operator.
    config: HashMap<String, f64>,
}

impl WasmKernel {
    /// Creates a kernel for `op_type`.
    ///
    /// Supported operators are `Add`, `Mul`, `MatMul`, `ReLU`, `Sigmoid` and
    /// `Softmax`; any other name yields a kernel whose `execute` returns an error.
    pub fn new(op_type: &str) -> Self {
        let kernel_fn = match op_type {
            "Add" => Self::execute_add,
            "Mul" => Self::execute_mul,
            "MatMul" => Self::execute_matmul,
            "ReLU" => Self::execute_relu,
            "Sigmoid" => Self::execute_sigmoid,
            "Softmax" => Self::execute_softmax,
            _ => Self::execute_fallback,
        };
        Self {
            op_type: op_type.to_string(),
            kernel_fn,
            stats: std::sync::Mutex::new(KernelStats {
                execution_count: 0,
                average_time_us: 0.0,
                min_time_us: 0.0,
                max_time_us: 0.0,
            }),
            memory_usage: MemoryUsage {
                peak_bytes: 0,
                current_bytes: 0,
                allocation_count: 0,
            },
            config: HashMap::new(),
        }
    }

    /// Wraps `data` in a row-major `F32` tensor of the given shape.
    fn f32_tensor(data: Vec<f32>, shape: Vec<usize>) -> Result<Tensor> {
        Ok(Tensor::from_data(
            data,
            shape,
            DataType::F32,
            TensorLayout::RowMajor,
        )?)
    }

    /// Shared driver for binary element-wise operators on two same-shape inputs.
    ///
    /// `op_name` appears in the error messages; `op` computes the flat result.
    fn execute_elementwise(
        inputs: &[Tensor],
        op_name: &str,
        op: fn(&[f32], &[f32], &mut [f32]) -> Result<()>,
    ) -> Result<Vec<Tensor>> {
        if inputs.len() != 2 {
            return Err(anyhow!("{} operation requires exactly 2 inputs", op_name));
        }
        let (a, b) = (&inputs[0], &inputs[1]);
        if a.shape() != b.shape() {
            return Err(anyhow!(
                "Shape mismatch for {}: {:?} vs {:?}",
                op_name,
                a.shape(),
                b.shape()
            ));
        }
        let a_data = a.to_vec()?;
        let b_data = b.to_vec()?;
        let mut result_data = vec![0.0f32; a_data.len()];
        op(&a_data, &b_data, &mut result_data)?;
        Ok(vec![Self::f32_tensor(result_data, a.shape().to_vec())?])
    }

    /// `Add`: element-wise sum of two same-shape tensors.
    fn execute_add(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        Self::execute_elementwise(inputs, "Add", WasmSimd128Ops::simd_add_f32)
    }

    /// `Mul`: element-wise product of two same-shape tensors.
    fn execute_mul(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        Self::execute_elementwise(inputs, "Mul", WasmSimd128Ops::simd_mul_f32)
    }

    /// `MatMul`: product of an `m x k` and a `k x n` 2-D tensor.
    fn execute_matmul(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        if inputs.len() != 2 {
            return Err(anyhow!("MatMul operation requires exactly 2 inputs"));
        }
        let (a, b) = (&inputs[0], &inputs[1]);
        let a_shape = a.shape();
        let b_shape = b.shape();
        if a_shape.len() != 2 || b_shape.len() != 2 {
            return Err(anyhow!("MatMul requires 2D tensors"));
        }
        let (m, k, n) = (a_shape[0], a_shape[1], b_shape[1]);
        if k != b_shape[0] {
            return Err(anyhow!(
                "Matrix dimension mismatch: {} != {}",
                k,
                b_shape[0]
            ));
        }
        let a_data = a.to_vec()?;
        let b_data = b.to_vec()?;
        let mut result_data = vec![0.0f32; m * n];
        WasmSimd128Ops::simd_matmul_f32(&a_data, &b_data, &mut result_data, m, n, k)?;
        Ok(vec![Self::f32_tensor(result_data, vec![m, n])?])
    }

    /// `ReLU`: element-wise `max(x, 0)`.
    fn execute_relu(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        if inputs.len() != 1 {
            return Err(anyhow!("ReLU operation requires exactly 1 input"));
        }
        let input = &inputs[0];
        let input_data = input.to_vec()?;
        let mut result_data = vec![0.0f32; input_data.len()];
        WasmSimd128Ops::simd_relu_f32(&input_data, &mut result_data)?;
        Ok(vec![Self::f32_tensor(result_data, input.shape().to_vec())?])
    }

    /// `Sigmoid`: element-wise `1 / (1 + e^-x)`.
    fn execute_sigmoid(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        if inputs.len() != 1 {
            return Err(anyhow!("Sigmoid operation requires exactly 1 input"));
        }
        let input = &inputs[0];
        let result_data: Vec<f32> = input
            .to_vec()?
            .iter()
            .map(|&x| 1.0 / (1.0 + (-x).exp()))
            .collect();
        Ok(vec![Self::f32_tensor(result_data, input.shape().to_vec())?])
    }

    /// `Softmax` along the last axis (ONNX default `axis = -1`).
    ///
    /// Each row along the last dimension is normalised independently;
    /// subtracting the row maximum first avoids overflow in `exp`. A rank-0
    /// tensor is treated as a single row.
    fn execute_softmax(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        if inputs.len() != 1 {
            return Err(anyhow!("Softmax operation requires exactly 1 input"));
        }
        let input = &inputs[0];
        let mut data = input.to_vec()?;
        let row_len = input.shape().last().copied().unwrap_or(data.len());
        // A zero-length last axis means there is nothing to normalise (and
        // `chunks_mut(0)` would panic).
        if row_len > 0 {
            for row in data.chunks_mut(row_len) {
                let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                let mut sum_exp = 0.0f32;
                for x in row.iter_mut() {
                    *x = (*x - max_val).exp();
                    sum_exp += *x;
                }
                for x in row.iter_mut() {
                    *x /= sum_exp;
                }
            }
        }
        Ok(vec![Self::f32_tensor(data, input.shape().to_vec())?])
    }

    /// Implementation for operator names this backend does not support.
    fn execute_fallback(&self, _inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        Err(anyhow!(
            "Operation {} not implemented for WASM",
            self.op_type
        ))
    }

    /// Starts a wall-clock timer where the platform provides one.
    ///
    /// `std::time::Instant::now()` panics on `wasm32-unknown-unknown`, so no
    /// timer (and therefore no stats entry) is produced on that target.
    fn start_timer() -> Option<std::time::Instant> {
        if cfg!(all(target_arch = "wasm32", target_os = "unknown")) {
            None
        } else {
            Some(std::time::Instant::now())
        }
    }

    /// Folds one execution time (microseconds) into the running min/max/mean.
    fn update_stats(&self, execution_time_us: f64) {
        // `update_stats` never panics while holding the lock, but recover from
        // poisoning anyway rather than turning stats into a crash source.
        let mut stats = self
            .stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        stats.execution_count += 1;
        if stats.execution_count == 1 {
            stats.min_time_us = execution_time_us;
            stats.max_time_us = execution_time_us;
            stats.average_time_us = execution_time_us;
        } else {
            stats.min_time_us = stats.min_time_us.min(execution_time_us);
            stats.max_time_us = stats.max_time_us.max(execution_time_us);
            let n = stats.execution_count as f64;
            stats.average_time_us = ((n - 1.0) * stats.average_time_us + execution_time_us) / n;
        }
    }
}

impl CompiledKernel for WasmKernel {
    /// Runs the selected operator and, on success, records its execution time.
    fn execute(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        let timer = Self::start_timer();
        let results = (self.kernel_fn)(self, inputs)?;
        if let Some(start) = timer {
            self.update_stats(start.elapsed().as_micros() as f64);
        }
        Ok(results)
    }

    fn get_memory_usage(&self) -> MemoryUsage {
        self.memory_usage.clone()
    }

    /// Returns a snapshot of the timing statistics gathered by `execute`.
    fn get_performance_stats(&self) -> KernelStats {
        self.stats
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}
pub fn create_wasm_kernel(op_type: &str) -> WasmKernel {
WasmKernel::new(op_type)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row-major F32 tensor for tests.
    fn tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor {
        Tensor::from_data(data, shape, DataType::F32, TensorLayout::RowMajor).unwrap()
    }

    #[test]
    fn test_simd_availability() {
        let available = WasmSimd128Ops::is_simd_available();
        // Off wasm32 the scalar fallbacks are always usable.
        if cfg!(not(target_arch = "wasm32")) {
            assert!(available);
        }
    }

    #[test]
    fn test_simd_add() -> Result<()> {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let mut result = vec![0.0; 5];
        WasmSimd128Ops::simd_add_f32(&a, &b, &mut result)?;
        assert_eq!(result, vec![3.0, 5.0, 7.0, 9.0, 11.0]);
        Ok(())
    }

    #[test]
    fn test_simd_mul() -> Result<()> {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 2.0, 2.0, 2.0];
        let mut result = vec![0.0; 4];
        WasmSimd128Ops::simd_mul_f32(&a, &b, &mut result)?;
        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0]);
        Ok(())
    }

    #[test]
    fn test_simd_relu() -> Result<()> {
        let input = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let mut output = vec![0.0; 5];
        WasmSimd128Ops::simd_relu_f32(&input, &mut output)?;
        assert_eq!(output, vec![0.0, 0.0, 0.0, 0.5, 1.0]);
        Ok(())
    }

    #[test]
    fn test_wasm_kernel_add() -> Result<()> {
        let kernel = create_wasm_kernel("Add");
        let a = tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = tensor(vec![4.0, 5.0, 6.0], vec![3]);
        let results = kernel.execute(&[a, b])?;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].to_vec().unwrap(), vec![5.0, 7.0, 9.0]);
        Ok(())
    }

    #[test]
    fn test_wasm_kernel_matmul() -> Result<()> {
        let kernel = create_wasm_kernel("MatMul");
        let a = tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = tensor(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let results = kernel.execute(&[a, b])?;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].shape(), &[2, 2]);
        assert_eq!(results[0].to_vec().unwrap(), vec![19.0, 22.0, 43.0, 50.0]);
        Ok(())
    }

    #[test]
    fn test_wasm_kernel_matmul_non_square() -> Result<()> {
        // Non-square operands catch row/column indexing mix-ups.
        let kernel = create_wasm_kernel("MatMul");
        let a = tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
        let results = kernel.execute(&[a, b])?;
        assert_eq!(results[0].shape(), &[2, 2]);
        assert_eq!(results[0].to_vec().unwrap(), vec![58.0, 64.0, 139.0, 154.0]);
        Ok(())
    }

    #[test]
    fn test_wasm_kernel_relu() -> Result<()> {
        let kernel = create_wasm_kernel("ReLU");
        let input = tensor(vec![-1.0, 0.0, 1.0, -2.0, 3.0], vec![5]);
        let results = kernel.execute(&[input])?;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].to_vec().unwrap(), vec![0.0, 0.0, 1.0, 0.0, 3.0]);
        Ok(())
    }

    #[test]
    fn test_wasm_kernel_sigmoid() -> Result<()> {
        let kernel = create_wasm_kernel("Sigmoid");
        let results = kernel.execute(&[tensor(vec![0.0], vec![1])])?;
        assert_eq!(results[0].to_vec().unwrap(), vec![0.5]);
        Ok(())
    }

    #[test]
    fn test_wasm_kernel_softmax_1d() -> Result<()> {
        let kernel = create_wasm_kernel("Softmax");
        let results = kernel.execute(&[tensor(vec![0.0, 0.0], vec![2])])?;
        assert_eq!(results[0].to_vec().unwrap(), vec![0.5, 0.5]);
        Ok(())
    }

    #[test]
    fn test_unsupported_operation() {
        let kernel = create_wasm_kernel("UnsupportedOp");
        let input = tensor(vec![1.0], vec![1]);
        let result = kernel.execute(&[input]);
        assert!(result.is_err());
    }
}