// Fused GEMM + bias + activation. F32 only.
// C = activation(A @ B + bias)
// activation_type in params: 0=None, 1=ReLU, 2=GELU, 3=SiLU, 4=Sigmoid, 5=Tanh
// Tile edge length for the shared-memory tiling scheme. Must equal the
// @workgroup_size(16, 16, 1) attribute on both entry points below.
const TILE_SIZE: u32 = 16u;
// Workgroup-shared staging tiles holding one TILE_SIZE x TILE_SIZE sub-block
// of A and of B per iteration. Sized from TILE_SIZE (a module-scope const is
// a valid WGSL array element count) so the constant stays the single source
// of truth; previously the 16 was duplicated here as a magic number.
var<workgroup> tile_a: array<array<f32, TILE_SIZE>, TILE_SIZE>;
var<workgroup> tile_b: array<array<f32, TILE_SIZE>, TILE_SIZE>;
// Uniform parameters shared by both entry points.
// Matrix shapes (row-major): A is M x K, B is K x N, C is M x N.
struct GemmEpilogueParams {
M: u32, // rows of A and C
K: u32, // columns of A / rows of B (reduction dimension)
N: u32, // columns of B and C
batch_size: u32, // number of batches; read only by the batched entry point
activation_type: u32, // epilogue select: 0=None 1=ReLU 2=GELU 3=SiLU 4=Sigmoid 5=Tanh
_pad0: u32, // explicit padding to a 32-byte struct size for the uniform buffer
_pad1: u32,
_pad2: u32,
}
// Storage/uniform bindings, all in bind group 0.
// NOTE(review): a, b, and bias are only read in this shader; their access
// mode could be tightened to `read`, but that requires the host-side bind
// group layout to declare read-only storage — confirm host code first.
@group(0) @binding(0) var<storage, read_write> a: array<f32>; // input A: M*K floats (batched: batch_size*M*K)
@group(0) @binding(1) var<storage, read_write> b: array<f32>; // input B: K*N floats (batched: batch_size*K*N)
@group(0) @binding(2) var<storage, read_write> bias: array<f32>; // per-output-column bias: N floats, shared across batches
@group(0) @binding(3) var<storage, read_write> c: array<f32>; // output C: M*N floats (batched: batch_size*M*N)
@group(0) @binding(4) var<uniform> params: GemmEpilogueParams;
// Applies the selected epilogue activation to one scalar.
// kind: 0=None (identity), 1=ReLU, 2=GELU (tanh approximation), 3=SiLU,
// 4=Sigmoid, 5=Tanh. Any unrecognized code falls through to identity.
fn apply_activation(v: f32, kind: u32) -> f32 {
switch kind {
case 1u: {
// ReLU: clamp negatives to zero.
return max(v, 0.0);
}
case 2u: {
// GELU, tanh approximation:
// 0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))
let sqrt_2_over_pi = 0.7978845608;
let cubic_coeff = 0.044715;
let v_cubed = v * v * v;
let arg = sqrt_2_over_pi * (v + cubic_coeff * v_cubed);
return 0.5 * v * (1.0 + tanh(arg));
}
case 3u: {
// SiLU (swish): v * sigmoid(v), written as a single division.
return v / (1.0 + exp(-v));
}
case 4u: {
// Logistic sigmoid.
return 1.0 / (1.0 + exp(-v));
}
case 5u: {
return tanh(v);
}
default: {
// 0 or unknown: pass the value through unchanged.
return v;
}
}
}
// Entry point: C = activation(A @ B + bias), single (un-batched) GEMM.
// One invocation computes one element of C; the workgroup cooperatively
// stages 16x16 tiles of A and B in shared memory.
// Dispatch with ceil(N/16) x ceil(M/16) x 1 workgroups.
@compute @workgroup_size(16, 16, 1)
fn gemm_bias_act_f32(@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>) {
let M = params.M;
let K = params.K;
let N = params.N;
// Global output coordinates handled by this invocation.
let row = group_id.y * TILE_SIZE + local_id.y;
let col = group_id.x * TILE_SIZE + local_id.x;
var sum: f32 = 0.0;
// Walk the reduction (K) dimension one tile at a time; ceil-division covers
// a ragged final tile when K is not a multiple of TILE_SIZE.
let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;
for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
// Stage one element of A's tile; out-of-range slots are zero-filled so the
// inner product below needs no bounds checks.
let a_col = t * TILE_SIZE + local_id.x;
if (row < M && a_col < K) {
tile_a[local_id.y][local_id.x] = a[row * K + a_col];
} else {
tile_a[local_id.y][local_id.x] = 0.0;
}
// Same for B's tile (transposed staging pattern: y indexes the K axis).
let b_row = t * TILE_SIZE + local_id.y;
if (b_row < K && col < N) {
tile_b[local_id.y][local_id.x] = b[b_row * N + col];
} else {
tile_b[local_id.y][local_id.x] = 0.0;
}
// All invocations reach this barrier — num_tiles derives from uniform K,
// so the loop's control flow is workgroup-uniform, and out-of-range threads
// did not return early above.
workgroupBarrier();
// Partial dot product over this K-tile.
for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {
sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x];
}
// Prevent faster threads from refilling the tiles while slower ones
// are still reading the current contents.
workgroupBarrier();
}
// Epilogue: per-column bias, then activation, then a guarded store.
if (row < M && col < N) {
c[row * N + col] = apply_activation(sum + bias[col], params.activation_type);
}
}
// Entry point: batched C[i] = activation(A[i] @ B[i] + bias) for each batch i.
// Identical tiling scheme to gemm_bias_act_f32; group_id.z selects the batch.
// bias (length N) is shared across all batches — no batch offset is applied.
// Dispatch with ceil(N/16) x ceil(M/16) x batch_size workgroups.
@compute @workgroup_size(16, 16, 1)
fn gemm_bias_act_batched_f32(@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>) {
let M = params.M;
let K = params.K;
let N = params.N;
let batch = group_id.z;
// Safe despite the barriers below: group_id.z is identical for every
// invocation in a workgroup, so this early return is workgroup-uniform.
if (batch >= params.batch_size) { return; }
// Global output coordinates handled by this invocation.
let row = group_id.y * TILE_SIZE + local_id.y;
let col = group_id.x * TILE_SIZE + local_id.x;
// Flat offsets of this batch's A, B, and C matrices.
let a_off = batch * M * K;
let b_off = batch * K * N;
let c_off = batch * M * N;
var sum: f32 = 0.0;
// Ceil-division covers a ragged final tile when K % TILE_SIZE != 0.
let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE;
for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
// Stage A's tile; zero-fill out-of-range slots so the inner product
// below needs no bounds checks.
let a_col = t * TILE_SIZE + local_id.x;
if (row < M && a_col < K) {
tile_a[local_id.y][local_id.x] = a[a_off + row * K + a_col];
} else {
tile_a[local_id.y][local_id.x] = 0.0;
}
// Same for B's tile (y indexes the K axis here).
let b_row = t * TILE_SIZE + local_id.y;
if (b_row < K && col < N) {
tile_b[local_id.y][local_id.x] = b[b_off + b_row * N + col];
} else {
tile_b[local_id.y][local_id.x] = 0.0;
}
// Uniform control flow: num_tiles derives from uniform K, and the only
// earlier return is workgroup-uniform, so every thread reaches this.
workgroupBarrier();
// Partial dot product over this K-tile.
for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {
sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x];
}
// Keep faster threads from overwriting tiles still being read.
workgroupBarrier();
}
// Epilogue: shared per-column bias, activation, guarded store into this batch.
if (row < M && col < N) {
c[c_off + row * N + col] = apply_activation(sum + bias[col], params.activation_type);
}
}