/// Builds the WGSL source of a GEMM compute shader: `C = alpha * (A x B) + beta * C`.
///
/// The generated shader reads `a` (binding 0) and `b` (binding 1), updates `c`
/// (binding 2) in place, and takes `m`, `n`, `k`, `alpha` and `beta` from the
/// uniform `GemmParams` (binding 3). Indexing is row-major: `a[row * k + i]`,
/// `b[i * n + col]`, `c[row * n + col]`. Each invocation computes one element
/// of `c`, using `gid.y` as the row and `gid.x` as the column; invocations that
/// fall outside `m x n` return immediately.
///
/// When `beta` is exactly zero the previous contents of `c` are not read, so
/// NaN/Inf values left in a reused output buffer cannot propagate into the
/// result (the usual BLAS convention for `beta == 0`).
///
/// # Parameters
///
/// * `tile_size` - edge length of the square workgroup, emitted verbatim as
///   `@workgroup_size(tile_size, tile_size)`. The value is not validated here:
///   WGSL requires it to be at least 1, and `tile_size * tile_size` has to stay
///   within the target device's per-workgroup invocation limit.
///
/// # Returns
///
/// The complete shader source, with entry point `main`.
pub fn gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct GemmParams {{
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
}}
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;
@compute @workgroup_size({ts}, {ts})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let row = gid.y;
    let col = gid.x;
    if (row >= params.m || col >= params.n) {{ return; }}
    var acc: f32 = 0.0;
    for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
        acc += a[row * params.k + i] * b[i * params.n + col];
    }}
    let idx = row * params.n + col;
    // With beta == 0 the old value of c is ignored instead of being multiplied
    // by zero, so stale NaN/Inf in the output buffer cannot leak into the result.
    if (params.beta == 0.0) {{
        c[idx] = params.alpha * acc;
    }} else {{
        c[idx] = params.alpha * acc + params.beta * c[idx];
    }}
}}
"#,
        ts = tile_size
    )
}
/// Builds the WGSL source of a unary element-wise compute shader.
///
/// The shader reads `input` (binding 0), applies the selected expression to
/// each element `x`, and stores the result at the same index of `output`
/// (binding 1). It uses one invocation per element in workgroups of 256 and
/// returns early for indices at or beyond `arrayLength(&input)`.
///
/// # Parameters
///
/// * `op` - operation name. Recognised names are `"relu"`, `"sigmoid"`,
///   `"tanh"`, `"exp"`, `"log"`, `"sqrt"`, `"abs"` and `"neg"`. Any other
///   name produces the identity (`output[i] = x`).
///
/// # Returns
///
/// The complete shader source, with entry point `main`.
pub fn elementwise_wgsl(op: &str) -> String {
    // (operation name, WGSL expression over the loaded element `x`).
    const EXPRESSIONS: [(&str, &str); 8] = [
        ("relu", "max(x, 0.0)"),
        ("sigmoid", "1.0 / (1.0 + exp(-x))"),
        ("tanh", "tanh(x)"),
        ("exp", "exp(x)"),
        ("log", "log(x)"),
        ("sqrt", "sqrt(x)"),
        ("abs", "abs(x)"),
        ("neg", "-x"),
    ];

    // Everything the shader contains up to the right-hand side of the store.
    const PROLOGUE: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= arrayLength(&input)) { return; }
let x = input[i];
output[i] = "#;

    // Terminates the store statement and closes `main`.
    const EPILOGUE: &str = ";\n}\n";

    // Unrecognised names fall back to the identity expression.
    let expression = EXPRESSIONS
        .iter()
        .find(|(name, _)| *name == op)
        .map_or("x", |(_, expr)| *expr);

    [PROLOGUE, expression, EPILOGUE].concat()
}
/// Builds the WGSL source of the first pass of a parallel reduction.
///
/// Each workgroup of 256 invocations loads up to 256 consecutive elements of
/// `input` (binding 0) into workgroup memory, reduces them with a tree
/// reduction, and thread 0 writes the workgroup's result to
/// `partial_sums[workgroup_id.x]` (binding 1). The element count `n` comes
/// from the uniform `ReduceParams` (binding 2); slots at or beyond `n` are
/// filled with the operation's identity element so they do not affect the
/// result.
///
/// # Parameters
///
/// * `op` - `"max"` or `"min"` select the corresponding reduction; every other
///   name (including `"sum"` and `"mean"`) produces a plain sum. No division
///   is performed for `"mean"`, so any scaling has to happen outside this
///   shader.
///
/// # Returns
///
/// The complete shader source, with entry point `main`.
pub fn reduction_wgsl(op: &str) -> String {
    // The identities for max/min must be the extreme finite f32 values. A
    // milder constant such as -1e38 would win against real inputs below it
    // whenever a workgroup is only partially filled. The long decimal form is
    // the exact value of the largest finite f32.
    const F32_LOWEST: &str = "f32(-3.40282346638528859812e+38)";
    const F32_HIGHEST: &str = "f32(3.40282346638528859812e+38)";

    let (neutral, combine) = match op {
        "max" => (F32_LOWEST, "max(acc, val)"),
        "min" => (F32_HIGHEST, "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };
    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
    n: u32,
}}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {{
    let tid = lid.x;
    let global_idx = gid.x;
    // Load or use neutral element when out of range.
    if (global_idx < params.n) {{
        shared_data[tid] = input[global_idx];
    }} else {{
        shared_data[tid] = {neutral};
    }}
    workgroupBarrier();
    // Parallel tree reduction within the workgroup.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}
    // Thread 0 writes the workgroup result to the partial-sums buffer.
    if (tid == 0u) {{
        partial_sums[wgid.x] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
/// Builds the WGSL source of the final pass of a parallel reduction.
///
/// The shader combines the `num_groups` values in `partial_sums` (binding 0)
/// into a single result stored at `output[0]` (binding 1). `num_groups` comes
/// from the uniform `FinalReduceParams` (binding 2). Each of the 256
/// invocations first folds every 256th partial starting at its own index, so
/// any number of partials is consumed, and the 256 per-thread values are then
/// combined with a tree reduction in workgroup memory.
///
/// Only `local_invocation_id` is used, so the shader is meant to be dispatched
/// as a single workgroup; additional workgroups would just recompute and
/// rewrite the same `output[0]`.
///
/// # Parameters
///
/// * `op` - `"max"` or `"min"` select the corresponding reduction; every other
///   name (including `"sum"` and `"mean"`) produces a plain sum. No division
///   is performed for `"mean"`, so any scaling has to happen outside this
///   shader.
///
/// # Returns
///
/// The complete shader source, with entry point `main`.
pub fn reduction_final_wgsl(op: &str) -> String {
    // The identities for max/min must be the extreme finite f32 values. A
    // milder constant such as -1e38 would win against real partials below it.
    // The long decimal form is the exact value of the largest finite f32.
    const F32_LOWEST: &str = "f32(-3.40282346638528859812e+38)";
    const F32_HIGHEST: &str = "f32(3.40282346638528859812e+38)";

    let (neutral, combine) = match op {
        "max" => (F32_LOWEST, "max(acc, val)"),
        "min" => (F32_HIGHEST, "min(acc, val)"),
        _ => ("f32(0.0)", "acc + val"),
    };
    format!(
        r#"
struct FinalReduceParams {{
    num_groups: u32,
}}
@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
) {{
    let tid = lid.x;
    // Each thread folds every 256th partial starting at its own index, so all
    // num_groups entries are consumed even when there are more than 256.
    var thread_acc: f32 = {neutral};
    for (var i: u32 = tid; i < params.num_groups; i = i + 256u) {{
        let acc = thread_acc;
        let val = partial_sums[i];
        thread_acc = {combine};
    }}
    shared_data[tid] = thread_acc;
    workgroupBarrier();
    // Parallel tree reduction of the 256 per-thread values.
    var stride: u32 = 128u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (tid < stride) {{
            let acc = shared_data[tid];
            let val = shared_data[tid + stride];
            shared_data[tid] = {combine};
        }}
        workgroupBarrier();
        stride = stride >> 1u;
    }}
    if (tid == 0u) {{
        output[0] = shared_data[0];
    }}
}}
"#,
        neutral = neutral,
        combine = combine,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails, naming the missing fragment, if `src` lacks any of `fragments`.
    fn assert_has_all(src: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                src.contains(*fragment),
                "generated WGSL is missing `{}`",
                fragment
            );
        }
    }

    // The GEMM shader carries the workgroup attribute and its parameter struct.
    #[test]
    fn wgsl_gemm_contains_workgroup() {
        assert_has_all(
            &gemm_wgsl(16),
            &["@compute @workgroup_size(16, 16)", "GemmParams", "alpha", "beta"],
        );
    }

    // The tile size is substituted into both workgroup dimensions.
    #[test]
    fn wgsl_gemm_tile_size_embedded() {
        for &tile in &[8u32, 32] {
            let attribute = format!("@workgroup_size({}, {})", tile, tile);
            assert_has_all(&gemm_wgsl(tile), &[attribute.as_str()]);
        }
    }

    #[test]
    fn wgsl_elementwise_relu_contains_max() {
        assert_has_all(&elementwise_wgsl("relu"), &["max(x, 0.0)"]);
    }

    // Every named op emits its expression; unknown names emit the identity.
    #[test]
    fn wgsl_elementwise_all_ops() {
        let expectations = [
            ("sigmoid", "exp(-x)"),
            ("tanh", "tanh(x)"),
            ("exp", "exp(x)"),
            ("log", "log(x)"),
            ("sqrt", "sqrt(x)"),
            ("abs", "abs(x)"),
            ("neg", "-x"),
            ("identity_op", "output[i] = x;"),
        ];
        for (op, fragment) in expectations.iter() {
            assert_has_all(&elementwise_wgsl(op), &[*fragment]);
        }
    }

    #[test]
    fn wgsl_reduction_sum_contains_addition() {
        assert_has_all(&reduction_wgsl("sum"), &["acc + val", "workgroupBarrier"]);
    }

    #[test]
    fn wgsl_reduction_max_uses_max_fn() {
        assert_has_all(&reduction_wgsl("max"), &["max(acc, val)"]);
    }

    #[test]
    fn wgsl_reduction_min_uses_min_fn() {
        assert_has_all(&reduction_wgsl("min"), &["min(acc, val)"]);
    }

    // "mean" has no dedicated shader: it must generate exactly the sum shader.
    #[test]
    fn wgsl_reduction_mean_same_as_sum() {
        assert_eq!(reduction_wgsl("sum"), reduction_wgsl("mean"));
    }

    #[test]
    fn wgsl_reduction_final_sum() {
        assert_has_all(&reduction_final_wgsl("sum"), &["num_groups", "output[0]"]);
    }
}