/// Generates a WGSL compute shader for single-precision GEMM:
/// `C = alpha * A * B + beta * C`, with row-major `A (m x k)`,
/// `B (k x n)` and `C (m x n)`.
///
/// `tile_size` sets both dimensions of the square workgroup; one
/// invocation computes exactly one element of `C` (rows map to
/// `gid.y`, columns to `gid.x`).
pub fn gemm_wgsl(tile_size: u32) -> String {
    // The template only parameterises the workgroup size; all matrix
    // dimensions arrive at runtime through the uniform `GemmParams`.
    format!(
        r#"
struct GemmParams {{
m: u32,
n: u32,
k: u32,
alpha: f32,
beta: f32,
}}
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: GemmParams;
@compute @workgroup_size({tile_size}, {tile_size})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let row = gid.y;
let col = gid.x;
if (row >= params.m || col >= params.n) {{ return; }}
var acc: f32 = 0.0;
for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
acc += a[row * params.k + i] * b[i * params.n + col];
}}
let idx = row * params.n + col;
c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
/// Generates a WGSL compute shader for batched single-precision GEMM.
///
/// Each batch `b` computes `C_b = alpha * A_b * B_b + beta * C_b`,
/// where the per-batch matrices live at element offsets
/// `b * stride_{{a,b,c}}` inside flat storage buffers. `gid.z` selects
/// the batch; `gid.y`/`gid.x` select the output row/column, so the
/// workgroup is `tile_size x tile_size x 1`.
pub fn batched_gemm_wgsl(tile_size: u32) -> String {
    format!(
        r#"
struct BatchedGemmParams {{
m: u32,
n: u32,
k: u32,
alpha: f32,
beta: f32,
batch_count: u32,
stride_a: u32,
stride_b: u32,
stride_c: u32,
}}
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> params: BatchedGemmParams;
@compute @workgroup_size({tile_size}, {tile_size})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let row = gid.y;
let col = gid.x;
let batch_index = gid.z;
if (batch_index >= params.batch_count || row >= params.m || col >= params.n) {{ return; }}
let a_offset = batch_index * params.stride_a;
let b_offset = batch_index * params.stride_b;
let c_offset = batch_index * params.stride_c;
var acc: f32 = 0.0;
for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
acc += a[a_offset + row * params.k + i] * b[b_offset + i * params.n + col];
}}
let idx = c_offset + row * params.n + col;
c[idx] = params.alpha * acc + params.beta * c[idx];
}}
"#
    )
}
/// Generates a WGSL GEMM shader with half-precision (`f16`) storage.
///
/// Inputs and outputs are `array<f16>` (requires the WGSL `f16`
/// extension on the device), but every product is widened to `f32`
/// and the accumulator is `f32`, so precision loss only happens on
/// the final store back to `c`.
pub fn gemm_wgsl_f16(tile_size: u32) -> String {
    format!(
        r#"
enable f16;
struct GemmParams {{
m: u32,
n: u32,
k: u32,
alpha: f32,
beta: f32,
}}
@group(0) @binding(0) var<storage, read> a: array<f16>;
@group(0) @binding(1) var<storage, read> b: array<f16>;
@group(0) @binding(2) var<storage, read_write> c: array<f16>;
@group(0) @binding(3) var<uniform> params: GemmParams;
@compute @workgroup_size({tile_size}, {tile_size})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let row = gid.y;
let col = gid.x;
if (row >= params.m || col >= params.n) {{ return; }}
var acc: f32 = 0.0;
for (var i: u32 = 0u; i < params.k; i = i + 1u) {{
acc += f32(a[row * params.k + i]) * f32(b[i * params.n + col]);
}}
let idx = row * params.n + col;
let prev = f32(c[idx]);
c[idx] = f16(params.alpha * acc + params.beta * prev);
}}
"#
    )
}
/// Generates a unary element-wise WGSL kernel that applies `op` to
/// every element of `input` and writes the result to `output`.
///
/// Recognised ops: `relu`, `sigmoid`, `tanh`, `exp`, `log`, `sqrt`,
/// `abs`, `neg`. Anything else falls back to the identity (`x`), so
/// an unknown name degrades to a copy rather than an error.
pub fn elementwise_wgsl(op: &str) -> String {
    let expr = match op {
        "relu" => "max(x, 0.0)",
        "sigmoid" => "1.0 / (1.0 + exp(-x))",
        "tanh" => "tanh(x)",
        "exp" => "exp(x)",
        "log" => "log(x)",
        "sqrt" => "sqrt(x)",
        "abs" => "abs(x)",
        "neg" => "-x",
        // Unknown op: pass the value through unchanged.
        _ => "x",
    };
    format!(
        r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let i = gid.x;
if (i >= arrayLength(&input)) {{ return; }}
let x = input[i];
output[i] = {expr};
}}
"#
    )
}
/// Generates a binary element-wise WGSL kernel combining `lhs` and
/// `rhs` element by element into `output`.
///
/// Recognised ops: `add`, `sub`, `mul`, `div`, `max`, `min`, `pow`.
/// Unknown ops fall back to `a` (the left-hand value), i.e. a copy of
/// `lhs`. Bounds are taken from `lhs`; `rhs` is assumed to be at
/// least as long.
pub fn binary_wgsl(op: &str) -> String {
    let expr = match op {
        "add" => "a + b",
        "sub" => "a - b",
        "mul" => "a * b",
        "div" => "a / b",
        "max" => "max(a, b)",
        "min" => "min(a, b)",
        "pow" => "pow(a, b)",
        // Unknown op: emit the left operand unchanged.
        _ => "a",
    };
    format!(
        r#"
@group(0) @binding(0) var<storage, read> lhs: array<f32>;
@group(0) @binding(1) var<storage, read> rhs: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let i = gid.x;
if (i >= arrayLength(&lhs)) {{ return; }}
let a = lhs[i];
let b = rhs[i];
output[i] = {expr};
}}
"#
    )
}
/// Generates the first-stage WGSL reduction kernel: each 256-thread
/// workgroup collapses its slice of `input` into a single value in
/// `partial_sums[workgroup_id]` via a shared-memory tree reduction.
///
/// `op` selects the combiner: `"max"`, `"min"`, or (any other value,
/// including `"sum"` and `"mean"`) addition. The mean's final divide
/// is not performed here — this stage only produces partial results.
pub fn reduction_wgsl(op: &str) -> String {
    // Identity element loaded by out-of-range lanes so they cannot
    // perturb the tree reduction.
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };
    format!(
        r#"
// Reduction params: total element count.
struct ReduceParams {{
n: u32,
}}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_sums: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceParams;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wgid: vec3<u32>,
) {{
let tid = lid.x;
let global_idx = gid.x;
// Load or use neutral element when out of range.
if (global_idx < params.n) {{
shared_data[tid] = input[global_idx];
}} else {{
shared_data[tid] = {neutral};
}}
workgroupBarrier();
// Parallel tree reduction within the workgroup.
var stride: u32 = 128u;
loop {{
if (stride == 0u) {{ break; }}
if (tid < stride) {{
let acc = shared_data[tid];
let val = shared_data[tid + stride];
shared_data[tid] = {combine};
}}
workgroupBarrier();
stride = stride >> 1u;
}}
// Thread 0 writes the workgroup result to the partial-sums buffer.
if (tid == 0u) {{
partial_sums[wgid.x] = shared_data[0];
}}
}}
"#
    )
}
/// Generates a WGSL 2-D convolution kernel (NCHW layout) with all
/// shape, stride and padding values baked into the shader source as
/// compile-time constants.
///
/// Dispatch convention: `gid.x` is the output x coordinate, while
/// `(batch, out_channel, output y)` are flattened into `gid.y`.
/// Zero padding is handled by skipping taps whose input coordinate
/// falls outside `[0, h_in) x [0, w_in)` after subtracting the pad.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_wgsl(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> String {
    // All placeholders capture the parameters above by name.
    format!(
        r#"
// Conv2D NCHW — generated by oxicuda-webgpu
// input : [{n}, {c_in}, {h_in}, {w_in}]
// filter: [{k_out}, {c_in}, {fh}, {fw}]
// output: [{n}, {k_out}, {oh}, {ow}]
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> filter: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
// gid.x = output x (ox mapped across batches*k_out*oh)
// We flatten (batch, k, oy) into gid.y and ox into gid.x
let ox = gid.x;
let linear_y = gid.y;
let batch_k_oh = {n}u * {k_out}u * {oh}u;
if (ox >= {ow}u || linear_y >= batch_k_oh) {{ return; }}
let b = linear_y / ({k_out}u * {oh}u);
let rem = linear_y % ({k_out}u * {oh}u);
let kf = rem / {oh}u;
let oy = rem % {oh}u;
var acc: f32 = 0.0;
for (var ci: u32 = 0u; ci < {c_in}u; ci = ci + 1u) {{
for (var fy: u32 = 0u; fy < {fh}u; fy = fy + 1u) {{
for (var fx: u32 = 0u; fx < {fw}u; fx = fx + 1u) {{
let iy_raw = i32(oy * {stride_h}u + fy) - i32({pad_h}u);
let ix_raw = i32(ox * {stride_w}u + fx) - i32({pad_w}u);
if (iy_raw >= 0 && iy_raw < i32({h_in}u) && ix_raw >= 0 && ix_raw < i32({w_in}u)) {{
let iy = u32(iy_raw);
let ix = u32(ix_raw);
let in_idx = ((b * {c_in}u + ci) * {h_in}u + iy) * {w_in}u + ix;
let f_idx = ((kf * {c_in}u + ci) * {fh}u + fy) * {fw}u + fx;
acc += input[in_idx] * filter[f_idx];
}}
}}
}}
}}
let o_idx = ((b * {k_out}u + kf) * {oh}u + oy) * {ow}u + ox;
output[o_idx] = acc;
}}
"#
    )
}
/// Generates a WGSL shader for scaled dot-product attention:
/// `O = softmax(Q K^T * scale) V`, over row-major buffers shaped
/// `[batch_heads, seq, head_dim]` (Q uses `seq_q`, K/V use `seq_kv`).
///
/// One invocation handles one `(batch_head, query)` row; the softmax
/// is computed in the numerically stable max-subtracted form. When
/// `causal` is true, an upper-triangular mask (`sk > sq`) is baked
/// into the shader so masked keys receive a -1e38 sentinel score.
///
/// Fix: the kernel now zeroes its output row before accumulating.
/// Previously pass 2 did `o_buf[...] += ...` without initialisation,
/// so stale buffer contents — or a second dispatch of the same
/// kernel — were silently folded into the result.
pub fn attention_wgsl(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> String {
    // The mask is spliced in as an `if/else` prologue: masked keys
    // skip the dot product entirely and keep the sentinel score.
    // The non-causal variant degenerates to a plain block opener.
    let causal_check = if causal {
        "if (sk > sq) { score = f32(-1e38); } else {"
    } else {
        "{"
    };
    format!(
        r#"
// Scaled dot-product attention — generated by oxicuda-webgpu
// Q, K, V : [{batch_heads}, seq, {head_dim}]
// O : [{batch_heads}, {seq_q}, {head_dim}]
// scale : {scale}
// causal : {causal}
@group(0) @binding(0) var<storage, read> q_buf: array<f32>;
@group(0) @binding(1) var<storage, read> k_buf: array<f32>;
@group(0) @binding(2) var<storage, read> v_buf: array<f32>;
@group(0) @binding(3) var<storage, read_write> o_buf: array<f32>;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
let linear = gid.x;
let total = {batch_heads}u * {seq_q}u;
if (linear >= total) {{ return; }}
let bh = linear / {seq_q}u;
let sq = linear % {seq_q}u;
let q_base = (bh * {seq_q}u + sq) * {head_dim}u;
// Q and O rows share the same layout, so the offsets coincide.
let o_base = q_base;
// Pass 0: zero this thread's output row. Pass 2 accumulates into
// o_buf in place, so stale contents (or a previous dispatch) would
// otherwise leak into the result.
for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
o_buf[o_base + d] = 0.0;
}}
// Pass 1: find max score for numerical stability
var max_score: f32 = f32(-1e38);
for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
var score: f32 = 0.0;
{causal_check}
let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
score += q_buf[q_base + d] * k_buf[k_base + d];
}}
score *= f32({scale});
}}
if (score > max_score) {{ max_score = score; }}
}}
// Pass 2: compute exp(score - max), accumulate weighted V
var sum_exp: f32 = 0.0;
for (var sk: u32 = 0u; sk < {seq_kv}u; sk = sk + 1u) {{
var score: f32 = 0.0;
{causal_check}
let k_base = (bh * {seq_kv}u + sk) * {head_dim}u;
for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
score += q_buf[q_base + d] * k_buf[k_base + d];
}}
score *= f32({scale});
}}
let w = exp(score - max_score);
sum_exp += w;
let v_base = (bh * {seq_kv}u + sk) * {head_dim}u;
for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
// Accumulate in-place (we normalise after the loop).
o_buf[o_base + d] += w * v_buf[v_base + d];
}}
}}
// Pass 3: normalise
if (sum_exp > 0.0) {{
for (var d: u32 = 0u; d < {head_dim}u; d = d + 1u) {{
o_buf[o_base + d] /= sum_exp;
}}
}}
}}
"#
    )
}
/// Generates a WGSL kernel reducing one axis of an N-D tensor.
///
/// The tensor is addressed as `(outer, dk, inner)` with explicit
/// element strides; each 256-thread workgroup produces one output
/// element for a single `(outer, inner)` slot, first striding across
/// the `dk` axis per thread, then tree-reducing in shared memory.
/// The workgroup id is 2-D (`grid_x` wide) to dodge per-dimension
/// dispatch limits. `"mean"` divides the final value by `dk`;
/// `"max"`/`"min"` use their combiners; anything else sums.
pub fn reduction_nd_wgsl(op: &str) -> String {
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    // Per-thread strided combiner (accumulator named `acc`).
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };
    // Tree-phase combiner; `acc2` avoids shadowing `acc` above.
    let combine_alias = match op {
        "max" => "max(acc2, val)",
        "min" => "min(acc2, val)",
        _ => "acc2 + val",
    };
    let final_expr = match op {
        "mean" => "shared_data[0] / f32(params.dk)",
        _ => "shared_data[0]",
    };
    format!(
        r#"
struct ReduceNdParams {{
outer: u32,
dk: u32,
inner: u32,
outer_stride: u32,
dk_stride: u32,
inner_stride: u32,
grid_x: u32,
_pad: u32,
}}
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: ReduceNdParams;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wgid: vec3<u32>,
) {{
let tid = lid.x;
let total = params.outer * params.inner;
// Decode 2-D workgroup id back to a linear output slot.
let slot = wgid.y * params.grid_x + wgid.x;
if (slot >= total) {{ return; }}
let o = slot / params.inner;
let j = slot % params.inner;
let base = o * params.outer_stride + j * params.inner_stride;
// Strided per-thread reduction across the dk axis.
var acc: f32 = {neutral};
var i: u32 = tid;
loop {{
if (i >= params.dk) {{ break; }}
let val = input[base + i * params.dk_stride];
acc = {combine};
i = i + 256u;
}}
shared_data[tid] = acc;
workgroupBarrier();
// Tree reduction within the workgroup.
var stride: u32 = 128u;
loop {{
if (stride == 0u) {{ break; }}
if (tid < stride) {{
let acc2 = shared_data[tid];
let val = shared_data[tid + stride];
shared_data[tid] = {combine_alias};
}}
workgroupBarrier();
stride = stride >> 1u;
}}
if (tid == 0u) {{
output[slot] = {final_expr};
}}
}}
"#
    )
}
/// Generates the second-stage WGSL reduction kernel that folds up to
/// 256 per-workgroup partial results into the single scalar
/// `output[0]`, using one 256-thread workgroup.
///
/// `op` selects the combiner (`"max"`, `"min"`, otherwise addition)
/// and must match the op used by the first stage. Lanes beyond
/// `num_groups` load the identity element.
pub fn reduction_final_wgsl(op: &str) -> String {
    let neutral = match op {
        "max" => "f32(-1e38)",
        "min" => "f32(1e38)",
        _ => "f32(0.0)",
    };
    let combine = match op {
        "max" => "max(acc, val)",
        "min" => "min(acc, val)",
        _ => "acc + val",
    };
    format!(
        r#"
struct FinalReduceParams {{
num_groups: u32,
}}
@group(0) @binding(0) var<storage, read> partial_sums: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: FinalReduceParams;
var<workgroup> shared_data: array<f32, 256>;
@compute @workgroup_size(256)
fn main(
@builtin(local_invocation_id) lid: vec3<u32>,
) {{
let tid = lid.x;
if (tid < params.num_groups) {{
shared_data[tid] = partial_sums[tid];
}} else {{
shared_data[tid] = {neutral};
}}
workgroupBarrier();
var stride: u32 = 128u;
loop {{
if (stride == 0u) {{ break; }}
if (tid < stride) {{
let acc = shared_data[tid];
let val = shared_data[tid + stride];
shared_data[tid] = {combine};
}}
workgroupBarrier();
stride = stride >> 1u;
}}
if (tid == 0u) {{
output[0] = shared_data[0];
}}
}}
"#
    )
}
#[cfg(test)]
mod tests {
// These tests inspect the generated WGSL *text* only — nothing is
// compiled by naga or dispatched to a device. They pin the shader
// source contract: entry-point attributes, binding declarations,
// parameter struct names, and the op-specific expressions each
// generator splices in.
use super::*;
// --- GEMM family ------------------------------------------------
#[test]
fn wgsl_gemm_contains_workgroup() {
let src = gemm_wgsl(16);
assert!(src.contains("@compute @workgroup_size(16, 16)"));
assert!(src.contains("GemmParams"));
assert!(src.contains("alpha"));
assert!(src.contains("beta"));
}
#[test]
fn wgsl_gemm_tile_size_embedded() {
// The tile size must appear verbatim in the workgroup attribute.
let src8 = gemm_wgsl(8);
assert!(src8.contains("@workgroup_size(8, 8)"));
let src32 = gemm_wgsl(32);
assert!(src32.contains("@workgroup_size(32, 32)"));
}
// --- Element-wise unary ops ---------------------------------------
#[test]
fn wgsl_elementwise_relu_contains_max() {
let src = elementwise_wgsl("relu");
assert!(src.contains("max(x, 0.0)"));
}
#[test]
fn wgsl_elementwise_all_ops() {
// One spot-check per supported op; an unrecognised op must fall
// back to the identity expression.
assert!(elementwise_wgsl("sigmoid").contains("exp(-x)"));
assert!(elementwise_wgsl("tanh").contains("tanh(x)"));
assert!(elementwise_wgsl("exp").contains("exp(x)"));
assert!(elementwise_wgsl("log").contains("log(x)"));
assert!(elementwise_wgsl("sqrt").contains("sqrt(x)"));
assert!(elementwise_wgsl("abs").contains("abs(x)"));
assert!(elementwise_wgsl("neg").contains("-x"));
assert!(elementwise_wgsl("identity_op").contains("output[i] = x;"));
}
// --- Single-axis reduction (stage 1 + final stage) ------------------
#[test]
fn wgsl_reduction_sum_contains_addition() {
let src = reduction_wgsl("sum");
assert!(src.contains("acc + val"));
assert!(src.contains("workgroupBarrier"));
}
#[test]
fn wgsl_reduction_max_uses_max_fn() {
let src = reduction_wgsl("max");
assert!(src.contains("max(acc, val)"));
}
#[test]
fn wgsl_reduction_min_uses_min_fn() {
let src = reduction_wgsl("min");
assert!(src.contains("min(acc, val)"));
}
#[test]
fn wgsl_reduction_mean_same_as_sum() {
// Stage 1 of "mean" is identical to "sum": the divide happens
// elsewhere, not in this kernel.
let sum_src = reduction_wgsl("sum");
let mean_src = reduction_wgsl("mean");
assert_eq!(sum_src, mean_src);
}
#[test]
fn wgsl_reduction_final_sum() {
let src = reduction_final_wgsl("sum");
assert!(src.contains("num_groups"));
assert!(src.contains("output[0]"));
}
// --- N-D axis reduction ---------------------------------------------
#[test]
fn wgsl_reduction_nd_sum_contains_addition() {
let src = reduction_nd_wgsl("sum");
assert!(src.contains("acc + val"));
assert!(src.contains("acc2 + val"));
assert!(src.contains("workgroupBarrier"));
assert!(src.contains("ReduceNdParams"));
}
#[test]
fn wgsl_reduction_nd_max_uses_max_fn() {
let src = reduction_nd_wgsl("max");
assert!(src.contains("max(acc, val)"));
assert!(src.contains("max(acc2, val)"));
}
#[test]
fn wgsl_reduction_nd_min_uses_min_fn() {
let src = reduction_nd_wgsl("min");
assert!(src.contains("min(acc, val)"));
assert!(src.contains("min(acc2, val)"));
}
#[test]
fn wgsl_reduction_nd_mean_divides_by_dk() {
// Unlike the flat reduction, the N-D kernel folds the mean's
// divide into its final write.
let src = reduction_nd_wgsl("mean");
assert!(src.contains("shared_data[0] / f32(params.dk)"));
assert!(src.contains("acc + val"));
}
#[test]
fn wgsl_reduction_nd_sum_does_not_divide() {
let src = reduction_nd_wgsl("sum");
assert!(!src.contains("/ f32(params.dk)"));
}
#[test]
fn wgsl_reduction_nd_decodes_2d_dispatch() {
let src = reduction_nd_wgsl("sum");
assert!(src.contains("wgid.y * params.grid_x + wgid.x"));
}
#[test]
fn wgsl_reduction_nd_uses_strided_loop() {
// Per-thread stride must match the 256-thread workgroup size.
let src = reduction_nd_wgsl("sum");
assert!(src.contains("i = i + 256u"));
}
// --- Element-wise binary ops ----------------------------------------
#[test]
fn wgsl_binary_add() {
let src = binary_wgsl("add");
assert!(src.contains("a + b"));
assert!(src.contains("lhs"));
assert!(src.contains("rhs"));
}
#[test]
fn wgsl_binary_all_ops() {
// Unknown ops must fall back to emitting the left operand.
assert!(binary_wgsl("sub").contains("a - b"));
assert!(binary_wgsl("mul").contains("a * b"));
assert!(binary_wgsl("div").contains("a / b"));
assert!(binary_wgsl("max").contains("max(a, b)"));
assert!(binary_wgsl("min").contains("min(a, b)"));
assert!(binary_wgsl("pow").contains("pow(a, b)"));
assert!(binary_wgsl("unknown_op").contains("output[i] = a;"));
}
#[test]
fn wgsl_binary_workgroup_size() {
let src = binary_wgsl("add");
assert!(src.contains("@workgroup_size(256)"));
}
// --- Conv2D -----------------------------------------------------------
#[test]
fn wgsl_conv2d_contains_workgroup() {
let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
assert!(src.contains("@compute @workgroup_size(8, 8)"));
}
#[test]
fn wgsl_conv2d_contains_storage_bindings() {
let src = conv2d_wgsl(1, 3, 32, 32, 16, 3, 3, 30, 30, 1, 1, 0, 0);
assert!(src.contains("var<storage, read> input:"));
assert!(src.contains("var<storage, read> filter:"));
assert!(src.contains("var<storage, read_write> output:"));
}
#[test]
fn wgsl_conv2d_embeds_dimensions() {
// Shape constants are baked into the shader as u32 literals.
let src = conv2d_wgsl(2, 8, 64, 64, 32, 5, 5, 60, 60, 1, 1, 0, 0);
assert!(src.contains("8u")); assert!(src.contains("64u")); assert!(src.contains("32u")); assert!(src.contains("5u")); assert!(src.contains("60u")); }
#[test]
fn wgsl_conv2d_has_padding_check() {
let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 8, 8, 1, 1, 1, 1);
assert!(src.contains("iy_raw >= 0"));
assert!(src.contains("ix_raw >= 0"));
}
#[test]
fn wgsl_conv2d_has_stride() {
let src = conv2d_wgsl(1, 1, 8, 8, 1, 3, 3, 3, 3, 2, 2, 0, 0);
assert!(src.contains("2u")); }
// --- Attention ----------------------------------------------------------
#[test]
fn wgsl_attention_contains_workgroup() {
let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
assert!(src.contains("@compute @workgroup_size(64)"));
}
#[test]
fn wgsl_attention_contains_storage_bindings() {
let src = attention_wgsl(4, 8, 8, 64, 0.125, false);
assert!(src.contains("var<storage, read> q_buf:"));
assert!(src.contains("var<storage, read> k_buf:"));
assert!(src.contains("var<storage, read> v_buf:"));
assert!(src.contains("var<storage, read_write> o_buf:"));
}
#[test]
fn wgsl_attention_stable_softmax() {
// The softmax must use the max-subtracted (numerically stable) form.
let src = attention_wgsl(1, 4, 4, 32, 0.25, false);
assert!(src.contains("max_score"));
assert!(src.contains("exp(score - max_score)"));
assert!(src.contains("sum_exp"));
}
#[test]
fn wgsl_attention_causal_mask() {
// The mask is baked in at generation time: the comparison must be
// present iff `causal` is true.
let src_causal = attention_wgsl(1, 4, 4, 32, 0.25, true);
assert!(src_causal.contains("sk > sq"));
let src_non_causal = attention_wgsl(1, 4, 4, 32, 0.25, false);
assert!(!src_non_causal.contains("sk > sq"));
}
#[test]
fn wgsl_attention_embeds_scale() {
let src = attention_wgsl(2, 16, 16, 64, 0.125, false);
assert!(src.contains("0.125"));
}
// --- Batched GEMM --------------------------------------------------------
#[test]
fn wgsl_batched_gemm_contains_batch_params() {
let src = batched_gemm_wgsl(16);
assert!(src.contains("batch_count"));
assert!(src.contains("stride_a"));
assert!(src.contains("stride_b"));
assert!(src.contains("stride_c"));
}
#[test]
fn wgsl_batched_gemm_contains_workgroup() {
let src = batched_gemm_wgsl(16);
assert!(src.contains("@compute @workgroup_size(16, 16)"));
assert!(src.contains("BatchedGemmParams"));
}
#[test]
fn wgsl_batched_gemm_uses_batch_index() {
// Batches are mapped onto the z axis of the dispatch.
let src = batched_gemm_wgsl(8);
assert!(src.contains("batch_index"));
assert!(src.contains("gid.z"));
}
#[test]
fn wgsl_batched_gemm_tile_size_embedded() {
let src8 = batched_gemm_wgsl(8);
assert!(src8.contains("@workgroup_size(8, 8)"));
let src32 = batched_gemm_wgsl(32);
assert!(src32.contains("@workgroup_size(32, 32)"));
}
// --- f16 GEMM ---------------------------------------------------------------
#[test]
fn wgsl_gemm_f16_enables_extension() {
let src = gemm_wgsl_f16(16);
assert!(src.contains("enable f16;"));
}
#[test]
fn wgsl_gemm_f16_uses_f16_storage() {
let src = gemm_wgsl_f16(16);
assert!(src.contains("array<f16>"));
}
#[test]
fn wgsl_gemm_f16_accumulates_in_f32() {
// Storage is f16 but the accumulator and widened loads are f32.
let src = gemm_wgsl_f16(16);
assert!(src.contains("var acc: f32 = 0.0;"));
assert!(src.contains("f32(a["));
assert!(src.contains("f32(b["));
}
#[test]
fn wgsl_gemm_f16_contains_workgroup() {
let src = gemm_wgsl_f16(8);
assert!(src.contains("@compute @workgroup_size(8, 8)"));
assert!(src.contains("GemmParams"));
}
#[test]
fn wgsl_attention_embeds_dimensions() {
let src = attention_wgsl(8, 32, 32, 128, 0.088, true);
assert!(src.contains("128u")); assert!(src.contains("32u")); assert!(src.contains("8u")); }
}