use crate::error::WasmError;
use wasm_bindgen::prelude::*;
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
use core::arch::wasm32::*;
#[wasm_bindgen]
pub fn simd_dot_product_f32(a: &[f32], b: &[f32]) -> Result<f32, JsValue> {
    // Dot product of two equal-length f32 slices. Dispatches at compile
    // time: the SIMD kernel on wasm32 with simd128, otherwise the unrolled
    // scalar fallback. Errors if the slice lengths differ.
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        let msg = format!(
            "simd_dot_product_f32: length mismatch: {} vs {}",
            len_a, len_b
        );
        return Err(WasmError::InvalidParameter(msg).into());
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    return Ok(dot_product_simd(a, b));
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    return Ok(dot_product_scalar(a, b));
}
#[wasm_bindgen]
pub fn simd_matrix_multiply_f32(a: &[f32], b: &[f32], n: usize) -> Result<Vec<f32>, JsValue> {
    // Multiplies two row-major n×n matrices. Dedicated SIMD kernels handle
    // n == 4 and n == 8 on wasm32+simd128; every other size falls back to
    // the scalar triple loop. Errors on n == 0, n*n overflow, or slices
    // that are not exactly n*n elements long.
    if n == 0 {
        return Err(WasmError::InvalidParameter("n must be > 0".to_string()).into());
    }
    // Reject any n whose square would wrap around usize.
    let expected = match n.checked_mul(n) {
        Some(sq) => sq,
        None => {
            return Err(
                WasmError::InvalidParameter("n² overflows usize".to_string()).into(),
            )
        }
    };
    if a.len() != expected || b.len() != expected {
        let msg = format!(
            "simd_matrix_multiply_f32: expected {}×{} = {} elements, got a={} b={}",
            n,
            n,
            expected,
            a.len(),
            b.len()
        );
        return Err(WasmError::InvalidParameter(msg).into());
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    match n {
        4 => return Ok(matmul4x4_simd(a, b)),
        8 => return Ok(matmul8x8_simd(a, b)),
        _ => {}
    }
    Ok(matmul_scalar(a, b, n))
}
#[wasm_bindgen]
pub fn simd_softmax_f32(input: &[f32]) -> Result<Vec<f32>, JsValue> {
    // Softmax over the whole slice, picking the SIMD or scalar
    // implementation at compile time. Empty input is rejected.
    if input.is_empty() {
        let err = WasmError::InvalidParameter(
            "simd_softmax_f32: input must not be empty".to_string(),
        );
        return Err(err.into());
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    return Ok(softmax_simd(input));
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    return Ok(softmax_scalar(input));
}
#[wasm_bindgen]
pub fn simd_relu_f32(input: &[f32]) -> Vec<f32> {
    // Element-wise ReLU (max(x, 0)). Infallible, so no Result wrapper;
    // the implementation is chosen at compile time.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    return relu_simd(input);
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    return relu_scalar(input);
}
#[wasm_bindgen]
pub fn simd_sigmoid_f32(input: &[f32]) -> Vec<f32> {
    // Element-wise logistic sigmoid: 1 / (1 + e^(-x)). Despite the name
    // this path is scalar-only — there is no wasm SIMD exp instruction.
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(1.0_f32 / (1.0_f32 + (-x).exp()));
    }
    out
}
#[wasm_bindgen]
pub fn simd_add_f32(a: &[f32], b: &[f32]) -> Result<Vec<f32>, JsValue> {
    // Element-wise sum of two equal-length slices; errors on a length
    // mismatch. SIMD kernel on wasm32+simd128, scalar loop otherwise.
    if a.len() != b.len() {
        let msg = format!(
            "simd_add_f32: length mismatch: {} vs {}",
            a.len(),
            b.len()
        );
        return Err(WasmError::InvalidParameter(msg).into());
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    return Ok(add_simd(a, b));
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    {
        let mut out = Vec::with_capacity(a.len());
        for (&x, &y) in a.iter().zip(b.iter()) {
            out.push(x + y);
        }
        return Ok(out);
    }
}
#[wasm_bindgen]
pub fn simd_mul_f32(a: &[f32], b: &[f32]) -> Result<Vec<f32>, JsValue> {
    // Element-wise (Hadamard) product of two equal-length slices; errors
    // on a length mismatch. SIMD on wasm32+simd128, scalar otherwise.
    if a.len() != b.len() {
        let msg = format!(
            "simd_mul_f32: length mismatch: {} vs {}",
            a.len(),
            b.len()
        );
        return Err(WasmError::InvalidParameter(msg).into());
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    return Ok(mul_simd(a, b));
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    {
        let mut out = Vec::with_capacity(a.len());
        for (&x, &y) in a.iter().zip(b.iter()) {
            out.push(x * y);
        }
        return Ok(out);
    }
}
#[wasm_bindgen]
pub fn simd_l2_norm_f32(input: &[f32]) -> f32 {
    // Euclidean (L2) norm: sqrt of the dot product of `input` with itself.
    // Reuses whichever dot-product kernel this build compiled in.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    let squared = dot_product_simd(input, input);
    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
    let squared = dot_product_scalar(input, input);
    squared.sqrt()
}
#[wasm_bindgen]
pub fn simd_ops_available() -> bool {
    // Reports whether the SIMD kernels were compiled into this build.
    // Uses the exact same predicate as every dispatch site in this file
    // (the original checked only `target_feature = "simd128"`, without the
    // `target_arch = "wasm32"` half), so the answer always matches the
    // code paths actually taken. `cfg!` evaluates to a plain bool.
    cfg!(all(target_arch = "wasm32", target_feature = "simd128"))
}
#[inline(always)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Scalar dot product, manually unrolled by four with independent
    // partial sums so the floating-point adds have no single loop-carried
    // dependency chain. The tail (len % 4 elements) is handled last.
    let len = a.len();
    let unrolled = (len / 4) * 4;
    let (mut s0, mut s1, mut s2, mut s3) = (0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32);
    let mut i = 0;
    while i < unrolled {
        s0 += a[i] * b[i];
        s1 += a[i + 1] * b[i + 1];
        s2 += a[i + 2] * b[i + 2];
        s3 += a[i + 3] * b[i + 3];
        i += 4;
    }
    // Combine the four lanes, then fold in the remainder.
    let mut total = s0 + s1 + s2 + s3;
    for j in unrolled..len {
        total += a[j] * b[j];
    }
    total
}
fn matmul_scalar(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
    // Row-major n×n matrix multiply in i-k-j loop order: the innermost
    // loop walks `c` and `b` contiguously, which is cache-friendly and
    // easy for the compiler to vectorize.
    let mut c = vec![0.0_f32; n * n];
    for i in 0..n {
        let c_row = &mut c[i * n..(i + 1) * n];
        for (k, &aik) in a[i * n..(i + 1) * n].iter().enumerate() {
            let b_row = &b[k * n..(k + 1) * n];
            for (cv, &bv) in c_row.iter_mut().zip(b_row.iter()) {
                *cv += aik * bv;
            }
        }
    }
    c
}
fn softmax_scalar(input: &[f32]) -> Vec<f32> {
let max_val = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut exps: Vec<f32> = input.iter().map(|&x| (x - max_val).exp()).collect();
let sum: f32 = exps.iter().sum();
if sum == 0.0 {
let n = exps.len() as f32;
exps.iter_mut().for_each(|v| *v = 1.0 / n);
} else {
exps.iter_mut().for_each(|v| *v /= sum);
}
exps
}
fn relu_scalar(input: &[f32]) -> Vec<f32> {
    // ReLU: clamp every element to be non-negative via f32::max(x, 0).
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(x.max(0.0_f32));
    }
    out
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    // Dot product with 128-bit WASM SIMD: multiply-accumulate four f32
    // lanes per iteration into a vector accumulator, then horizontally sum
    // the four lanes and fold in the scalar tail (len % 4 elements).
    // Indexing is driven by a.len(); callers (the #[wasm_bindgen] wrappers
    // and simd_l2_norm_f32) guarantee b is at least as long as a.
    let n = a.len();
    let chunks = n / 4;
    let mut acc = f32x4_splat(0.0);
    for i in 0..chunks {
        let base = i * 4;
        // SAFETY: base + 4 <= chunks * 4 <= n == a.len(), so this 16-byte
        // read is in-bounds; wasm v128_load performs an unaligned load, so
        // the 4-byte alignment of &[f32] data is sufficient.
        let va = unsafe {
            v128_load(a.as_ptr().add(base) as *const v128)
        };
        // SAFETY: same bounds argument, relying on the caller invariant
        // b.len() >= a.len() (enforced by the public wrappers).
        let vb = unsafe {
            v128_load(b.as_ptr().add(base) as *const v128)
        };
        acc = f32x4_add(acc, f32x4_mul(va, vb));
    }
    // Horizontal reduction of the four partial sums.
    let mut sum = f32x4_extract_lane::<0>(acc)
        + f32x4_extract_lane::<1>(acc)
        + f32x4_extract_lane::<2>(acc)
        + f32x4_extract_lane::<3>(acc);
    // Scalar tail for the last n % 4 elements.
    for i in (chunks * 4)..n {
        sum += a[i] * b[i];
    }
    sum
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn matmul4x4_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
    // 4×4 row-major matrix multiply: each output row is built as a linear
    // combination of the rows of `b`, weighted by the matching row of `a`
    // (broadcast-multiply-accumulate, one v128 per row).
    // Requires a.len() == b.len() == 16; the wrapper validates this.
    let mut c = vec![0.0_f32; 16];
    for row in 0..4_usize {
        let mut c_row = f32x4_splat(0.0);
        for k in 0..4_usize {
            // Broadcast a[row][k] across all four lanes.
            let aik = f32x4_splat(a[row * 4 + k]);
            // SAFETY: k < 4, so k*4 + 4 <= 16 == b.len(); the 16-byte read
            // is in-bounds, and wasm v128_load tolerates unaligned data.
            let b_row = unsafe {
                v128_load(b.as_ptr().add(k * 4) as *const v128)
            };
            c_row = f32x4_add(c_row, f32x4_mul(aik, b_row));
        }
        // SAFETY: row < 4, so row*4 + 4 <= 16 == c.len(); in-bounds store.
        unsafe {
            v128_store(c.as_mut_ptr().add(row * 4) as *mut v128, c_row);
        }
    }
    c
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn matmul8x8_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
    // 8×8 row-major matrix multiply: same broadcast-multiply-accumulate
    // scheme as matmul4x4_simd, but each 8-wide row is held in two v128
    // halves (lanes 0-3 in *_lo, lanes 4-7 in *_hi).
    // Requires a.len() == b.len() == 64; the wrapper validates this.
    let mut c = vec![0.0_f32; 64];
    for row in 0..8_usize {
        // Two vector accumulators per output row.
        let mut c_lo = f32x4_splat(0.0);
        let mut c_hi = f32x4_splat(0.0);
        for k in 0..8_usize {
            // Broadcast a[row][k] to all lanes of both halves.
            let aik = f32x4_splat(a[row * 8 + k]);
            // SAFETY: k < 8, so k*8 + 4 <= 60 + 4 <= 64 == b.len(); the
            // 16-byte read is in-bounds and v128_load is unaligned-safe.
            let b_lo = unsafe {
                v128_load(b.as_ptr().add(k * 8) as *const v128)
            };
            // SAFETY: k*8 + 4 + 4 <= 64, covering the upper half of row k.
            let b_hi = unsafe {
                v128_load(b.as_ptr().add(k * 8 + 4) as *const v128)
            };
            c_lo = f32x4_add(c_lo, f32x4_mul(aik, b_lo));
            c_hi = f32x4_add(c_hi, f32x4_mul(aik, b_hi));
        }
        // SAFETY: row < 8, so row*8 + 4 + 4 <= 64 == c.len(); both stores
        // write disjoint, in-bounds 16-byte regions of `c`.
        unsafe {
            v128_store(c.as_mut_ptr().add(row * 8) as *mut v128, c_lo);
            v128_store(c.as_mut_ptr().add(row * 8 + 4) as *mut v128, c_hi);
        }
    }
    c
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn softmax_simd(input: &[f32]) -> Vec<f32> {
    // Numerically stable softmax: shift by the maximum before
    // exponentiating. The subtraction and the final normalization are
    // vectorized; exp() itself is computed one lane at a time because
    // wasm simd128 has no vector exponential instruction.
    let max_val = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let max_splat = f32x4_splat(max_val);
    let n = input.len();
    let chunks = n / 4;
    let mut exps = vec![0.0_f32; n];
    let mut sum = 0.0_f32;
    for i in 0..chunks {
        let base = i * 4;
        // SAFETY: base + 4 <= chunks * 4 <= n == input.len(), so the
        // 16-byte read is in-bounds; v128_load is unaligned-safe on wasm.
        let v = unsafe { v128_load(input.as_ptr().add(base) as *const v128) };
        let shifted = f32x4_sub(v, max_splat);
        // Per-lane scalar exp (no SIMD exp available).
        let e0 = f32x4_extract_lane::<0>(shifted).exp();
        let e1 = f32x4_extract_lane::<1>(shifted).exp();
        let e2 = f32x4_extract_lane::<2>(shifted).exp();
        let e3 = f32x4_extract_lane::<3>(shifted).exp();
        exps[base] = e0;
        exps[base + 1] = e1;
        exps[base + 2] = e2;
        exps[base + 3] = e3;
        sum += e0 + e1 + e2 + e3;
    }
    // Scalar tail for the last n % 4 elements.
    for i in (chunks * 4)..n {
        let e = (input[i] - max_val).exp();
        exps[i] = e;
        sum += e;
    }
    if sum == 0.0 {
        // Every exponential underflowed (or the input was all -inf):
        // return a uniform distribution rather than dividing by zero.
        let inv = 1.0_f32 / n as f32;
        return vec![inv; n];
    }
    // Vectorized normalization: multiply by a precomputed 1/sum.
    let inv_sum = f32x4_splat(1.0_f32 / sum);
    for i in 0..chunks {
        let base = i * 4;
        // SAFETY: same bounds argument as the load loop above, now over
        // `exps` (exps.len() == n); load and store touch the same 16 bytes.
        let v = unsafe { v128_load(exps.as_ptr().add(base) as *const v128) };
        let normed = f32x4_mul(v, inv_sum);
        unsafe { v128_store(exps.as_mut_ptr().add(base) as *mut v128, normed) };
    }
    // NOTE(review): the tail divides by `sum` while the SIMD path
    // multiplies by 1/sum — the two can differ in the last ulp between
    // lane positions; presumably acceptable here, but worth confirming.
    for i in (chunks * 4)..n {
        exps[i] /= sum;
    }
    exps
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn relu_simd(input: &[f32]) -> Vec<f32> {
    // Element-wise ReLU: max(x, 0) on four lanes at a time, with a scalar
    // tail for the last len % 4 elements.
    let n = input.len();
    let chunks = n / 4;
    let zero = f32x4_splat(0.0);
    let mut out = vec![0.0_f32; n];
    for i in 0..chunks {
        let base = i * 4;
        // SAFETY: base + 4 <= chunks * 4 <= n == input.len() == out.len(),
        // so both the 16-byte read and the 16-byte write below are
        // in-bounds; v128 loads/stores are unaligned-safe on wasm.
        let v = unsafe { v128_load(input.as_ptr().add(base) as *const v128) };
        let r = f32x4_max(v, zero);
        unsafe { v128_store(out.as_mut_ptr().add(base) as *mut v128, r) };
    }
    // Scalar tail.
    for i in (chunks * 4)..n {
        out[i] = input[i].max(0.0);
    }
    out
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn add_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
    // Element-wise sum, four f32 lanes per iteration plus a scalar tail.
    // Indexing is driven by a.len(); the wrapper guarantees equal lengths.
    let n = a.len();
    let chunks = n / 4;
    let mut out = vec![0.0_f32; n];
    for i in 0..chunks {
        let base = i * 4;
        // SAFETY: base + 4 <= chunks * 4 <= n == a.len(); in-bounds
        // 16-byte read, unaligned-safe on wasm.
        let va = unsafe { v128_load(a.as_ptr().add(base) as *const v128) };
        // SAFETY: same bounds, relying on the wrapper-enforced invariant
        // b.len() == a.len().
        let vb = unsafe { v128_load(b.as_ptr().add(base) as *const v128) };
        let r = f32x4_add(va, vb);
        // SAFETY: out.len() == n, so the 16-byte store is in-bounds.
        unsafe { v128_store(out.as_mut_ptr().add(base) as *mut v128, r) };
    }
    // Scalar tail.
    for i in (chunks * 4)..n {
        out[i] = a[i] + b[i];
    }
    out
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn mul_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
    // Element-wise product, four f32 lanes per iteration plus a scalar
    // tail. Indexing is driven by a.len(); the wrapper guarantees equal
    // lengths.
    let n = a.len();
    let chunks = n / 4;
    let mut out = vec![0.0_f32; n];
    for i in 0..chunks {
        let base = i * 4;
        // SAFETY: base + 4 <= chunks * 4 <= n == a.len(); in-bounds
        // 16-byte read, unaligned-safe on wasm.
        let va = unsafe { v128_load(a.as_ptr().add(base) as *const v128) };
        // SAFETY: same bounds, relying on the wrapper-enforced invariant
        // b.len() == a.len().
        let vb = unsafe { v128_load(b.as_ptr().add(base) as *const v128) };
        let r = f32x4_mul(va, vb);
        // SAFETY: out.len() == n, so the 16-byte store is in-bounds.
        unsafe { v128_store(out.as_mut_ptr().add(base) as *mut v128, r) };
    }
    // Scalar tail.
    for i in (chunks * 4)..n {
        out[i] = a[i] * b[i];
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product_basic() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [1.0_f32, 1.0, 1.0, 1.0];
        let result = simd_dot_product_f32(&a, &b).expect("dot product ok");
        assert!((result - 10.0).abs() < 1e-6, "expected 10, got {result}");
    }

    #[test]
    fn test_dot_product_non_multiple_of_4() {
        // Exercises the scalar remainder path (5 = 4 + 1 elements).
        let a = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let b = [2.0_f32, 2.0, 2.0, 2.0, 2.0];
        let result = simd_dot_product_f32(&a, &b).expect("dot product ok");
        assert!((result - 30.0).abs() < 1e-5);
    }

    #[test]
    fn test_dot_product_length_mismatch() {
        let result = simd_dot_product_f32(&[1.0], &[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_4x4_identity() {
        let identity: Vec<f32> = vec![
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        let a: Vec<f32> = (1..=16).map(|x| x as f32).collect();
        let result = simd_matrix_multiply_f32(&a, &identity, 4).expect("matmul ok");
        for (r, e) in result.iter().zip(a.iter()) {
            assert!((r - e).abs() < 1e-5, "expected {e}, got {r}");
        }
    }

    #[test]
    fn test_matmul_bad_size() {
        // Mismatched element counts must be rejected (the old version of
        // this test only checked a *valid* 3x3 input, so bad sizes were
        // never actually exercised).
        let bad = vec![1.0_f32; 8];
        let good = vec![1.0_f32; 9];
        assert!(simd_matrix_multiply_f32(&bad, &good, 3).is_err());
        assert!(simd_matrix_multiply_f32(&good, &bad, 3).is_err());
        // A correctly sized 3x3 input is accepted.
        assert!(simd_matrix_multiply_f32(&good, &good, 3).is_ok());
    }

    #[test]
    fn test_matmul_zero_n() {
        let result = simd_matrix_multiply_f32(&[], &[], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = simd_softmax_f32(&input).expect("softmax ok");
        let total: f32 = result.iter().sum();
        assert!((total - 1.0).abs() < 1e-6, "sum = {total}");
        for &v in &result {
            assert!(v >= 0.0 && v <= 1.0);
        }
    }

    #[test]
    fn test_softmax_empty() {
        assert!(simd_softmax_f32(&[]).is_err());
    }

    #[test]
    fn test_relu_basic() {
        let input = [-3.0_f32, -1.0, 0.0, 1.0, 3.0];
        let result = simd_relu_f32(&input);
        let expected = [0.0_f32, 0.0, 0.0, 1.0, 3.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-7, "expected {e}, got {r}");
        }
    }

    #[test]
    fn test_add_f32() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let result = simd_add_f32(&a, &b).expect("add ok");
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_f32_length_mismatch() {
        assert!(simd_add_f32(&[1.0], &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_mul_f32() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0];
        let result = simd_mul_f32(&a, &b).expect("mul ok");
        assert_eq!(result, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_l2_norm() {
        // 3-4-5 triangle.
        let v = [3.0_f32, 4.0];
        assert!((simd_l2_norm_f32(&v) - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_sigmoid() {
        // sigmoid(0) == 0.5 exactly.
        let v = [0.0_f32];
        let result = simd_sigmoid_f32(&v);
        assert!((result[0] - 0.5).abs() < 1e-6);
    }
}