/// WGSL compute shader: element-wise addition, `result[i] = a[i] + b[i]`.
///
/// Bindings (group 0), all `array<f32>` storage buffers:
/// * 0 — `a` (read)
/// * 1 — `b` (read)
/// * 2 — `result` (read_write)
///
/// Workgroup size is 256. Each invocation handles the element at
/// `global_invocation_id.x` and returns early when that index is at or past
/// `arrayLength(&result)`. Only the length of `result` is checked by the
/// shader.
// NOTE(review): presumably the host binds `a` and `b` with at least as many
// elements as `result` — confirm against the dispatch code.
pub(super) const ELEMENTWISE_ADD_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> result : array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= arrayLength(&result)) { return; }
result[idx] = a[idx] + b[idx];
}
"#;
/// WGSL compute shader: element-wise subtraction, `result[i] = a[i] - b[i]`.
///
/// Bindings (group 0), all `array<f32>` storage buffers:
/// * 0 — `a` (read), the minuend
/// * 1 — `b` (read), the subtrahend
/// * 2 — `result` (read_write)
///
/// Workgroup size is 256. Each invocation handles the element at
/// `global_invocation_id.x` and returns early when that index is at or past
/// `arrayLength(&result)`. Only the length of `result` is checked by the
/// shader.
// NOTE(review): presumably the host binds `a` and `b` with at least as many
// elements as `result` — confirm against the dispatch code.
pub(super) const ELEMENTWISE_SUB_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> result : array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= arrayLength(&result)) { return; }
result[idx] = a[idx] - b[idx];
}
"#;
/// WGSL compute shader: element-wise (Hadamard) product,
/// `result[i] = a[i] * b[i]`.
///
/// Bindings (group 0), all `array<f32>` storage buffers:
/// * 0 — `a` (read)
/// * 1 — `b` (read)
/// * 2 — `result` (read_write)
///
/// Workgroup size is 256. Each invocation handles the element at
/// `global_invocation_id.x` and returns early when that index is at or past
/// `arrayLength(&result)`. Only the length of `result` is checked by the
/// shader.
// NOTE(review): presumably the host binds `a` and `b` with at least as many
// elements as `result` — confirm against the dispatch code.
pub(super) const ELEMENTWISE_MUL_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> result : array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= arrayLength(&result)) { return; }
result[idx] = a[idx] * b[idx];
}
"#;
/// WGSL compute shader: multiply every element by a scalar,
/// `result[i] = a[i] * uniforms.scalar`.
///
/// Bindings (group 0):
/// * 0 — `a`: `array<f32>` storage (read)
/// * 1 — `result`: `array<f32>` storage (read_write)
/// * 2 — `uniforms`: uniform struct `{ scalar: f32, n: u32 }`
///
/// Workgroup size is 256. Each invocation handles the element at
/// `global_invocation_id.x`. Unlike the element-wise binary kernels in this
/// file, the bound comes from `uniforms.n` rather than from
/// `arrayLength(&result)`: invocations with `idx >= uniforms.n` return
/// without writing.
pub(super) const SCALAR_MUL_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read_write> result : array<f32>;
struct Uniforms {
scalar : f32,
n : u32,
};
@group(0) @binding(2) var<uniform> uniforms : Uniforms;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= uniforms.n) { return; }
result[idx] = a[idx] * uniforms.scalar;
}
"#;
/// WGSL compute shader: matrix product with one invocation per output
/// element.
///
/// Bindings (group 0):
/// * 0 — `a_mat`: `array<f32>` storage (read), indexed as `row * K + k`
/// * 1 — `b_mat`: `array<f32>` storage (read), indexed as `k * N + col`
/// * 2 — `c_mat`: `array<f32>` storage (read_write), indexed as
///   `row * N + col`
/// * 3 — `uniforms`: uniform struct of four `u32`s `{ M, N, K, _pad }`;
///   `_pad` is never read by the shader
///
/// The indexing above treats the buffers as flat row-major matrices of
/// M×K, K×N and M×N elements respectively.
///
/// Workgroup size is 16×16. `global_invocation_id.x` selects the output
/// column and `.y` the output row; invocations with `row >= M` or
/// `col >= N` return without writing. Each remaining invocation accumulates
/// the K-term dot product in an `f32` and stores it to `c_mat`.
pub(super) const MATMUL_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a_mat : array<f32>;
@group(0) @binding(1) var<storage, read> b_mat : array<f32>;
@group(0) @binding(2) var<storage, read_write> c_mat : array<f32>;
struct Uniforms {
M : u32,
N : u32,
K : u32,
_pad : u32,
};
@group(0) @binding(3) var<uniform> uniforms : Uniforms;
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let col = gid.x;
let row = gid.y;
if (row >= uniforms.M || col >= uniforms.N) { return; }
var acc : f32 = 0.0;
for (var k : u32 = 0u; k < uniforms.K; k++) {
acc += a_mat[row * uniforms.K + k] * b_mat[k * uniforms.N + col];
}
c_mat[row * uniforms.N + col] = acc;
}
"#;
/// WGSL compute shader: per-workgroup sum reduction producing one partial
/// sum per workgroup.
///
/// Bindings (group 0):
/// * 0 — `input`: `array<f32>` storage (read)
/// * 1 — `partial`: `array<f32>` storage (read_write); element
///   `workgroup_id.x` receives that workgroup's sum
/// * 2 — `uniforms`: uniform struct `{ n: u32 }`, the number of valid input
///   elements
///
/// Workgroup size is 256, matching the 256-element `wg_data` workgroup
/// array. Steps per workgroup:
/// 1. Each invocation copies `input[global_invocation_id.x]` into
///    `wg_data[local_invocation_id.x]`, or `0.0` when the global index is
///    at or past `uniforms.n`.
/// 2. After a `workgroupBarrier()`, a loop starts with `stride = 128` and
///    halves it each pass until it reaches 0; on each pass invocations with
///    `local_idx < stride` add `wg_data[local_idx + stride]` into
///    `wg_data[local_idx]`, followed by another barrier.
/// 3. Invocation 0 writes `wg_data[0]` to `partial[workgroup_id.x]`.
///
/// The shader itself does not combine the per-workgroup partial sums.
// NOTE(review): presumably the host performs a further pass (or a CPU sum)
// over `partial` to obtain the final total — confirm against the caller.
pub(super) const SUM_REDUCE_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input : array<f32>;
@group(0) @binding(1) var<storage, read_write> partial : array<f32>;
struct Uniforms {
n : u32,
};
@group(0) @binding(2) var<uniform> uniforms : Uniforms;
var<workgroup> wg_data : array<f32, 256>;
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) gid : vec3<u32>,
@builtin(local_invocation_id) lid : vec3<u32>,
@builtin(workgroup_id) wgid : vec3<u32>
) {
let idx = gid.x;
let local_idx = lid.x;
if (idx < uniforms.n) {
wg_data[local_idx] = input[idx];
} else {
wg_data[local_idx] = 0.0;
}
workgroupBarrier();
var stride : u32 = 128u;
loop {
if (stride == 0u) { break; }
if (local_idx < stride) {
wg_data[local_idx] += wg_data[local_idx + stride];
}
workgroupBarrier();
stride = stride / 2u;
}
if (local_idx == 0u) {
partial[wgid.x] = wg_data[0];
}
}
"#;
/// WGSL compute shader: 2-D transpose staged through a 16×16 workgroup
/// tile.
///
/// Bindings (group 0):
/// * 0 — `input`: `array<f32>` storage (read), indexed as
///   `in_row * cols + in_col`
/// * 1 — `output`: `array<f32>` storage (read_write), indexed as
///   `out_row * rows + out_col`
/// * 2 — `uniforms`: uniform struct `{ rows: u32, cols: u32 }` describing
///   the input extent
///
/// Workgroup size is 16×16. The `tile` workgroup array holds 272 = 16 × 17
/// floats and is addressed with a row pitch of 17 (the shader's own comment
/// attributes the extra column to avoiding bank conflicts).
///
/// Steps per workgroup:
/// 1. Each invocation whose input coordinate
///    (`workgroup_id.y * 16 + lid.y`, `workgroup_id.x * 16 + lid.x`) lies
///    inside `rows × cols` copies that element to
///    `tile[lid.y * 17 + lid.x]`.
/// 2. After a `workgroupBarrier()`, the workgroup's x/y block indices are
///    swapped to form the output coordinate, and each in-range invocation
///    stores `tile[lid.x * 17 + lid.y]` to `output`.
///
/// The `global_invocation_id` builtin is declared in the entry point
/// signature but is not used by the body.
pub(super) const TRANSPOSE_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input : array<f32>;
@group(0) @binding(1) var<storage, read_write> output : array<f32>;
struct Uniforms {
rows : u32,
cols : u32,
};
@group(0) @binding(2) var<uniform> uniforms : Uniforms;
// +1 pad avoids bank conflicts for 16-wide tiles
var<workgroup> tile : array<f32, 272>; // 16 * 17
@compute @workgroup_size(16, 16)
fn main(
@builtin(global_invocation_id) gid : vec3<u32>,
@builtin(local_invocation_id) lid : vec3<u32>,
@builtin(workgroup_id) wgid : vec3<u32>
) {
let in_col = wgid.x * 16u + lid.x;
let in_row = wgid.y * 16u + lid.y;
if (in_row < uniforms.rows && in_col < uniforms.cols) {
tile[lid.y * 17u + lid.x] = input[in_row * uniforms.cols + in_col];
}
workgroupBarrier();
let out_col = wgid.y * 16u + lid.x;
let out_row = wgid.x * 16u + lid.y;
if (out_row < uniforms.cols && out_col < uniforms.rows) {
output[out_row * uniforms.rows + out_col] = tile[lid.x * 17u + lid.y];
}
}
"#;
/// WGSL compute shader: concatenate two N-dimensional `f32` tensors along
/// an arbitrary axis, one invocation per output element.
///
/// Bindings (group 0):
/// * 0 — `a_buf`: `array<f32>` storage (read), first operand
/// * 1 — `b_buf`: `array<f32>` storage (read), second operand
/// * 2 — `result`: `array<f32>` storage (read_write)
/// * 3 — `uniforms`: uniform struct of four `u32`s
///   `{ axis, dim_a, ndim, _pad }`; `_pad` is never read
/// * 4 — `out_shape`: `array<u32>` storage (read); declared but never read
///   by the shader body
/// * 5 — `out_strides`: `array<u32>` storage (read)
/// * 6 — `a_strides`: `array<u32>` storage (read)
/// * 7 — `b_strides`: `array<u32>` storage (read)
///
/// Workgroup size is 256. Invocations with
/// `global_invocation_id.x >= arrayLength(&result)` return without writing.
/// Each remaining invocation:
/// 1. Splits its flat output index into per-dimension coordinates by
///    dividing by `out_strides[d]` and carrying the remainder forward, for
///    `d` in `0..ndim`.
/// 2. If the coordinate along `axis` is below `dim_a`, reads `a_buf` at the
///    dot product of the coordinates with `a_strides`.
/// 3. Otherwise subtracts `dim_a` from the `axis` coordinate and reads
///    `b_buf` at the dot product of the coordinates with `b_strides`.
// NOTE(review): `coords` is a fixed `array<u32, 8>`, so the shader only has
// room for `ndim <= 8` — confirm the host enforces that limit.
// NOTE(review): the divide/remainder decomposition presumably relies on
// `out_strides` being ordered from largest to smallest (contiguous
// row-major layout) — verify against the code that fills the buffer.
pub(super) const CONCAT_AXISN_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a_buf : array<f32>;
@group(0) @binding(1) var<storage, read> b_buf : array<f32>;
@group(0) @binding(2) var<storage, read_write> result : array<f32>;
struct ConcatUniforms {
axis : u32,
dim_a : u32,
ndim : u32,
_pad : u32,
};
@group(0) @binding(3) var<uniform> uniforms : ConcatUniforms;
@group(0) @binding(4) var<storage, read> out_shape : array<u32>;
@group(0) @binding(5) var<storage, read> out_strides : array<u32>;
@group(0) @binding(6) var<storage, read> a_strides : array<u32>;
@group(0) @binding(7) var<storage, read> b_strides : array<u32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_idx = gid.x;
let n_out = arrayLength(&result);
if (out_idx >= n_out) { return; }
let ndim = uniforms.ndim;
let axis = uniforms.axis;
let dim_a = uniforms.dim_a;
// Decompose out_idx into multi-dimensional coordinates using out_strides
var remaining : u32 = out_idx;
var coords : array<u32, 8>;
for (var d : u32 = 0u; d < ndim; d++) {
let s = out_strides[d];
coords[d] = remaining / s;
remaining = remaining % s;
}
// Determine whether this coordinate comes from A or B
let ax_coord = coords[axis];
if (ax_coord < dim_a) {
// Read from A: compute flat index using a_strides
var a_idx : u32 = 0u;
for (var d : u32 = 0u; d < ndim; d++) {
a_idx += coords[d] * a_strides[d];
}
result[out_idx] = a_buf[a_idx];
} else {
// Read from B: shift axis coordinate by -dim_a
coords[axis] = ax_coord - dim_a;
var b_idx : u32 = 0u;
for (var d : u32 = 0u; d < ndim; d++) {
b_idx += coords[d] * b_strides[d];
}
result[out_idx] = b_buf[b_idx];
}
}
"#;
/// WGSL compute shader: sum an N-dimensional `f32` tensor along one axis,
/// one invocation per output element.
///
/// Bindings (group 0):
/// * 0 — `input`: `array<f32>` storage (read)
/// * 1 — `result`: `array<f32>` storage (read_write)
/// * 2 — `uniforms`: uniform struct of four `u32`s
///   `{ axis, axis_size, ndim, in_axis_stride }`
/// * 3 — `in_shape`: `array<u32>` storage (read); declared but never read
///   by the shader body
/// * 4 — `in_strides`: `array<u32>` storage (read)
/// * 5 — `out_shape`: `array<u32>` storage (read); declared but never read
///   by the shader body
/// * 6 — `out_strides`: `array<u32>` storage (read)
///
/// Workgroup size is 256. Invocations with
/// `global_invocation_id.x >= arrayLength(&result)` return without writing.
/// Each remaining invocation:
/// 1. Splits its flat output index into `ndim - 1` coordinates by dividing
///    by `out_strides[d]` and carrying the remainder forward.
/// 2. Walks the `ndim` input dimensions, skipping `axis`, and accumulates
///    `out_coords[od] * in_strides[d]` into a base input offset (the
///    reduction axis contributes 0 at this stage).
/// 3. Adds up `axis_size` input elements starting at that base offset and
///    stepping by `in_axis_stride`, then stores the total in
///    `result[out_idx]`.
// NOTE(review): `out_ndim` is computed as `ndim - 1u` on an unsigned value,
// so `ndim == 0` would wrap to a huge loop bound — confirm the host never
// dispatches this kernel for rank-0 input.
// NOTE(review): `out_coords` is a fixed `array<u32, 8>`, so the output rank
// (`ndim - 1`) must fit in 8 entries — confirm the host enforces that limit.
// NOTE(review): `in_axis_stride` is presumably equal to `in_strides[axis]`;
// verify against the code that fills the uniform.
pub(super) const REDUCE_SUM_AXIS_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> input : array<f32>;
@group(0) @binding(1) var<storage, read_write> result : array<f32>;
struct ReduceUniforms {
axis : u32,
axis_size : u32,
ndim : u32,
in_axis_stride : u32,
};
@group(0) @binding(2) var<uniform> uniforms : ReduceUniforms;
@group(0) @binding(3) var<storage, read> in_shape : array<u32>;
@group(0) @binding(4) var<storage, read> in_strides : array<u32>;
@group(0) @binding(5) var<storage, read> out_shape : array<u32>;
@group(0) @binding(6) var<storage, read> out_strides : array<u32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_idx = gid.x;
let n_out = arrayLength(&result);
if (out_idx >= n_out) { return; }
let ndim = uniforms.ndim;
let axis = uniforms.axis;
let axis_size = uniforms.axis_size;
let in_axis_stride = uniforms.in_axis_stride;
let out_ndim = ndim - 1u;
// Decompose out_idx into coordinates for the (ndim-1)-dim output
var remaining : u32 = out_idx;
var out_coords : array<u32, 8>;
for (var d : u32 = 0u; d < out_ndim; d++) {
let s = out_strides[d];
out_coords[d] = remaining / s;
remaining = remaining % s;
}
// Map out_coords back to input coordinates (insert 0 for the reduction axis)
// and compute base flat offset in input with axis index = 0
var base_in : u32 = 0u;
var od : u32 = 0u; // out dimension cursor
for (var d : u32 = 0u; d < ndim; d++) {
if (d == axis) {
// reduction axis: contribute 0 (will be summed over j below)
} else {
base_in += out_coords[od] * in_strides[d];
od += 1u;
}
}
// Sum over all elements along the reduction axis
var acc : f32 = 0.0;
for (var j : u32 = 0u; j < axis_size; j++) {
acc += input[base_in + j * in_axis_stride];
}
result[out_idx] = acc;
}
"#;