use crate::webgpu::backend::WebGpuContext;
use crate::webgpu::shader_gen::{ElementwiseOp, ReductionOp};
use crate::webgpu::types::{GpuBufferUsage, GpuError, WebGpuConfig, WebGpuResult};
/// JS-oriented facade over a [`WebGpuContext`].
///
/// Every `js_*` method copies its slice arguments into new storage buffers on
/// the context, runs a single operation, and returns the result as an owned
/// value, so callers never handle buffer ids directly.
pub struct WasmWebGpu {
    /// Backend context used for every upload, dispatch, and download.
    ctx: WebGpuContext,
}
impl WasmWebGpu {
pub fn new() -> Self {
Self {
ctx: WebGpuContext::new(WebGpuConfig::default()),
}
}
pub fn with_config(config: WebGpuConfig) -> Self {
Self {
ctx: WebGpuContext::new(config),
}
}
pub fn js_matmul(
&mut self,
a: &[f32],
b: &[f32],
m: u32,
n: u32,
k: u32,
) -> WebGpuResult<Vec<f32>> {
let (m, n, k) = (m as usize, n as usize, k as usize);
let a_id = self
.ctx
.upload_buffer(a.to_vec(), GpuBufferUsage::Storage)?;
let b_id = self
.ctx
.upload_buffer(b.to_vec(), GpuBufferUsage::Storage)?;
let c_id = self.ctx.matmul(a_id, b_id, m, k, n)?;
self.ctx.download_buffer(c_id)
}
pub fn js_elementwise_relu(&mut self, data: &[f32]) -> WebGpuResult<Vec<f32>> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
let out_id = self.ctx.elementwise(id, None, ElementwiseOp::Relu)?;
self.ctx.download_buffer(out_id)
}
pub fn js_elementwise_sigmoid(&mut self, data: &[f32]) -> WebGpuResult<Vec<f32>> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
let out_id = self.ctx.elementwise(id, None, ElementwiseOp::Sigmoid)?;
self.ctx.download_buffer(out_id)
}
pub fn js_elementwise_exp(&mut self, data: &[f32]) -> WebGpuResult<Vec<f32>> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
let out_id = self.ctx.elementwise(id, None, ElementwiseOp::Exp)?;
self.ctx.download_buffer(out_id)
}
pub fn js_elementwise_log(&mut self, data: &[f32]) -> WebGpuResult<Vec<f32>> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
let out_id = self.ctx.elementwise(id, None, ElementwiseOp::Log)?;
self.ctx.download_buffer(out_id)
}
pub fn js_elementwise_add(&mut self, a: &[f32], b: &[f32]) -> WebGpuResult<Vec<f32>> {
if a.len() != b.len() {
return Err(GpuError::Execution(format!(
"add: length mismatch {} vs {}",
a.len(),
b.len()
)));
}
let a_id = self
.ctx
.upload_buffer(a.to_vec(), GpuBufferUsage::Storage)?;
let b_id = self
.ctx
.upload_buffer(b.to_vec(), GpuBufferUsage::Storage)?;
let out_id = self.ctx.elementwise(a_id, Some(b_id), ElementwiseOp::Add)?;
self.ctx.download_buffer(out_id)
}
pub fn js_reduction_sum(&mut self, data: &[f32]) -> WebGpuResult<f32> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
self.ctx.reduce(id, ReductionOp::Sum)
}
pub fn js_reduction_max(&mut self, data: &[f32]) -> WebGpuResult<f32> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
self.ctx.reduce(id, ReductionOp::Max)
}
pub fn js_reduction_min(&mut self, data: &[f32]) -> WebGpuResult<f32> {
let id = self
.ctx
.upload_buffer(data.to_vec(), GpuBufferUsage::Storage)?;
self.ctx.reduce(id, ReductionOp::Min)
}
}
impl Default for WasmWebGpu {
fn default() -> Self {
Self::new()
}
}
/// One-shot matrix multiply.
///
/// Builds a fresh [`WasmWebGpu`] for this call and forwards every argument
/// unchanged to [`WasmWebGpu::js_matmul`].
pub fn matmul_f32(a: &[f32], b: &[f32], m: u32, n: u32, k: u32) -> WebGpuResult<Vec<f32>> {
    let mut gpu = WasmWebGpu::new();
    gpu.js_matmul(a, b, m, n, k)
}
/// One-shot ReLU over `data` on a freshly built [`WasmWebGpu`].
///
/// See [`WasmWebGpu::js_elementwise_relu`].
pub fn relu_f32(data: &[f32]) -> WebGpuResult<Vec<f32>> {
    let mut gpu = WasmWebGpu::new();
    gpu.js_elementwise_relu(data)
}
/// One-shot sigmoid over `data` on a freshly built [`WasmWebGpu`].
///
/// See [`WasmWebGpu::js_elementwise_sigmoid`].
pub fn sigmoid_f32(data: &[f32]) -> WebGpuResult<Vec<f32>> {
    let mut gpu = WasmWebGpu::new();
    gpu.js_elementwise_sigmoid(data)
}
/// One-shot sum reduction of `data` on a freshly built [`WasmWebGpu`].
///
/// See [`WasmWebGpu::js_reduction_sum`].
pub fn reduce_sum_f32(data: &[f32]) -> WebGpuResult<f32> {
    let mut gpu = WasmWebGpu::new();
    gpu.js_reduction_sum(data)
}
/// One-shot max reduction of `data` on a freshly built [`WasmWebGpu`].
///
/// See [`WasmWebGpu::js_reduction_max`].
pub fn reduce_max_f32(data: &[f32]) -> WebGpuResult<f32> {
    let mut gpu = WasmWebGpu::new();
    gpu.js_reduction_max(data)
}
/// `wasm_bindgen` exports of the one-shot helpers.
///
/// Each export forwards to the matching free function and converts its error
/// into a JS string via the error's `to_string()`.
#[cfg(target_arch = "wasm32")]
mod wasm_export {
    use super::*;
    use wasm_bindgen::prelude::*;

    /// Renders `err` as a JS string value for the `Err` side of an export.
    fn to_js_error<E: ToString>(err: E) -> JsValue {
        JsValue::from_str(&err.to_string())
    }

    /// JS `gpu_matmul`: forwards to [`matmul_f32`].
    #[wasm_bindgen(js_name = "gpu_matmul")]
    pub fn wasm_matmul(
        a: &[f32],
        b: &[f32],
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<Vec<f32>, JsValue> {
        matmul_f32(a, b, m, n, k).map_err(to_js_error)
    }

    /// JS `gpu_relu`: forwards to [`relu_f32`].
    #[wasm_bindgen(js_name = "gpu_relu")]
    pub fn wasm_relu(data: &[f32]) -> Result<Vec<f32>, JsValue> {
        relu_f32(data).map_err(to_js_error)
    }

    /// JS `gpu_sigmoid`: forwards to [`sigmoid_f32`].
    #[wasm_bindgen(js_name = "gpu_sigmoid")]
    pub fn wasm_sigmoid(data: &[f32]) -> Result<Vec<f32>, JsValue> {
        sigmoid_f32(data).map_err(to_js_error)
    }

    /// JS `gpu_reduce_sum`: forwards to [`reduce_sum_f32`].
    #[wasm_bindgen(js_name = "gpu_reduce_sum")]
    pub fn wasm_reduce_sum(data: &[f32]) -> Result<f32, JsValue> {
        reduce_sum_f32(data).map_err(to_js_error)
    }

    /// JS `gpu_reduce_max`: forwards to [`reduce_max_f32`].
    #[wasm_bindgen(js_name = "gpu_reduce_max")]
    pub fn wasm_reduce_max(data: &[f32]) -> Result<f32, JsValue> {
        reduce_max_f32(data).map_err(to_js_error)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn gpu() -> WasmWebGpu {
        WasmWebGpu::new()
    }

    /// Asserts `got` has exactly `want.len()` elements, each within `tol` of `want`.
    ///
    /// The length check matters: zipping alone would let a short or empty
    /// result pass vacuously.
    fn assert_close(got: &[f32], want: &[f32], tol: f32) {
        assert_eq!(got.len(), want.len(), "length mismatch: {got:?} vs {want:?}");
        for (&r, &e) in got.iter().zip(want) {
            assert!((r - e).abs() < tol, "got {r}, expected {e}");
        }
    }

    #[test]
    fn test_js_matmul_2x2() {
        let mut g = gpu();
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let c = g.js_matmul(&a, &b, 2, 2, 2).expect("matmul");
        assert_close(&c, &[19.0, 22.0, 43.0, 50.0], 1e-4);
    }

    #[test]
    fn test_matmul_f32_free_fn() {
        let a = [1.0_f32, 0.0, 0.0, 1.0];
        let b = [3.0_f32, 7.0, 2.0, 5.0];
        let c = matmul_f32(&a, &b, 2, 2, 2).expect("matmul");
        // Identity on the left must reproduce `b`.
        assert_close(&c, &b, 1e-4);
    }

    #[test]
    fn test_js_relu_clips_negatives() {
        let mut g = gpu();
        let data = [-3.0_f32, -0.5, 0.0, 1.0, 4.0];
        let out = g.js_elementwise_relu(&data).expect("relu");
        assert_eq!(out, [0.0_f32, 0.0, 0.0, 1.0, 4.0]);
    }

    #[test]
    fn test_relu_f32_free_fn() {
        let out = relu_f32(&[-1.0_f32, 2.0, -3.0]).expect("relu");
        assert_eq!(out, [0.0_f32, 2.0, 0.0]);
    }

    #[test]
    fn test_js_sigmoid_in_range() {
        let mut g = gpu();
        let data: Vec<f32> = (-10..=10).map(|x| x as f32).collect();
        let out = g.js_elementwise_sigmoid(&data).expect("sigmoid");
        assert_eq!(out.len(), data.len(), "sigmoid must preserve length");
        for &v in &out {
            assert!(v > 0.0 && v < 1.0, "sigmoid out of (0,1): {v}");
        }
    }

    #[test]
    fn test_sigmoid_f32_free_fn() {
        let out = sigmoid_f32(&[0.0_f32]).expect("sigmoid");
        assert_close(&out, &[0.5], 1e-5);
    }

    #[test]
    fn test_js_reduction_sum_equals_direct() {
        let mut g = gpu();
        let data: Vec<f32> = (1..=50).map(|x| x as f32).collect();
        let expected: f32 = data.iter().sum();
        let sum = g.js_reduction_sum(&data).expect("sum");
        assert!((sum - expected).abs() < 1.0, "sum {sum} != {expected}");
    }

    #[test]
    fn test_js_reduction_max_equals_direct() {
        let mut g = gpu();
        let data = vec![3.0_f32, 1.0, 4.0, 1.5, 9.0, 2.6];
        let max = g.js_reduction_max(&data).expect("max");
        assert!((max - 9.0).abs() < 1e-5, "max should be 9.0, got {max}");
    }

    #[test]
    fn test_js_reduction_min_equals_direct() {
        let mut g = gpu();
        let data = vec![3.0_f32, 1.0, 4.0, 1.5, 9.0, 2.6];
        let min = g.js_reduction_min(&data).expect("min");
        assert!((min - 1.0).abs() < 1e-5, "min should be 1.0, got {min}");
    }

    #[test]
    fn test_reduce_sum_free_fn() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let sum = reduce_sum_f32(&data).expect("sum");
        assert!((sum - 10.0).abs() < 1e-5, "sum should be 10.0");
    }

    #[test]
    fn test_reduce_max_free_fn() {
        let data = vec![5.0_f32, 3.0, 8.0, 1.0];
        let max = reduce_max_f32(&data).expect("max");
        assert!((max - 8.0).abs() < 1e-5, "max should be 8.0");
    }

    #[test]
    fn test_js_elementwise_add() {
        let mut g = gpu();
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let out = g.js_elementwise_add(&a, &b).expect("add");
        assert_eq!(out, [5.0_f32, 7.0, 9.0]);
    }

    #[test]
    fn test_js_elementwise_add_length_mismatch_fails() {
        let mut g = gpu();
        let result = g.js_elementwise_add(&[1.0_f32], &[1.0_f32, 2.0]);
        assert!(result.is_err(), "length mismatch should be an error");
    }

    #[test]
    fn test_js_elementwise_exp() {
        let mut g = gpu();
        let out = g.js_elementwise_exp(&[0.0_f32, 1.0]).expect("exp");
        assert_close(&out, &[1.0, std::f32::consts::E], 1e-4);
    }

    #[test]
    fn test_js_elementwise_log() {
        let mut g = gpu();
        let out = g
            .js_elementwise_log(&[1.0_f32, std::f32::consts::E])
            .expect("log");
        assert_close(&out, &[0.0, 1.0], 1e-4);
    }

    #[test]
    fn test_wasm_webgpu_default() {
        let g = WasmWebGpu::default();
        assert!(!g.ctx.is_gpu_available());
    }

    #[test]
    fn test_with_config_custom_tile() {
        let cfg = WebGpuConfig {
            workgroup_size_x: 4,
            ..WebGpuConfig::default()
        };
        let mut g = WasmWebGpu::with_config(cfg);
        let out = g
            .js_matmul(&[3.0_f32], &[4.0_f32], 1, 1, 1)
            .expect("matmul");
        assert_close(&out, &[12.0], 1e-4);
    }
}