/// WGSL compute shader: tiled f32 matrix multiply, C = A × B.
///
/// Uses two 16×16 shared-memory tiles with a 16×16 workgroup; each thread
/// computes one element of C. Out-of-range tile elements are zero-filled,
/// so M, K, N need not be multiples of 16.
///
/// Bindings: 0 = A `[M, K]` (read), 1 = B `[K, N]` (read),
/// 2 = C `[M, N]` (read_write), 3 = uniform `Dimensions { M, K, N }`.
///
/// Dispatch note: `global_id.x` indexes rows and `global_id.y` columns, so
/// the host must dispatch ceil(M/16) workgroups in x and ceil(N/16) in y.
/// Both `workgroupBarrier()` calls sit in uniform control flow (no early
/// returns before the loop), as WGSL requires.
pub const MATMUL_SHADER: &str = r#"
const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

struct Dimensions {
    M: u32, // rows of A and C
    K: u32, // cols of A, rows of B
    N: u32, // cols of B and C
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

// Shared memory tiles — each 16×16 = 256 floats
var<workgroup> tile_a: array<f32, 256>;
var<workgroup> tile_b: array<f32, 256>;

// Workgroup size: 16×16 = 256 threads
@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let row = global_id.x;
    let col = global_id.y;
    let lr = local_id.x; // local row within tile [0..15]
    let lc = local_id.y; // local col within tile [0..15]

    var sum: f32 = 0.0;

    // Iterate over K dimension in tiles of 16
    let num_tiles = (dims.K + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load A tile: A[row, t*TILE + lc]
        let a_col = t * TILE + lc;
        if (row < dims.M && a_col < dims.K) {
            tile_a[lr * TILE + lc] = a[row * dims.K + a_col];
        } else {
            tile_a[lr * TILE + lc] = 0.0;
        }

        // Load B tile: B[t*TILE + lr, col]
        let b_row = t * TILE + lr;
        if (b_row < dims.K && col < dims.N) {
            tile_b[lr * TILE + lc] = b[b_row * dims.N + col];
        } else {
            tile_b[lr * TILE + lc] = 0.0;
        }

        // Wait for all threads to finish loading
        workgroupBarrier();

        // Accumulate partial dot product from shared memory
        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            sum = sum + tile_a[lr * TILE + k] * tile_b[k * TILE + lc];
        }

        // Wait before loading next tile (prevents overwriting while others read)
        workgroupBarrier();
    }

    // Write result
    if (row < dims.M && col < dims.N) {
        c[row * dims.N + col] = sum;
    }
}
"#;
93
/// WGSL compute shader: CUTLASS-style tiled GEMM, C = alpha * (A × B), f32.
///
/// 64×64 output tile per workgroup, K stepped in slices of 8, and a 4×4
/// register micro-tile per thread (16×16 threads × 16 outputs = 4096 = 64×64).
/// Shared memory is double-buffered: iteration `kt` reads buffer `kt % 2`
/// while prefetching tile `kt + 1` into the other buffer, so compute and the
/// next load never touch the same buffer between barriers.
///
/// Bindings: 0 = A `[M, K]`, 1 = B `[K, N]`, 2 = C `[M, N]`,
/// 3 = uniform `Dimensions { M, K, N, alpha }`.
///
/// Dispatch note: workgroup x covers columns (`bn`), workgroup y covers rows
/// (`bm`) — the opposite of MATMUL_SHADER's thread mapping.
///
/// Fixes vs. previous revision: removed the unused `k_offset` local in the
/// main loop, and corrected the shared-memory comment (the code uses four
/// separate workgroup arrays, not one combined `smem` buffer).
pub const TILED_GEMM_SHADER: &str = r#"
// CUTLASS-derived tiled GEMM — 64×64 tiles, 4×4 thread micro-tiles
// Algorithm from NVIDIA CUTLASS (MIT licensed), reimplemented in WGSL.

const BM: u32 = 64u; // thread-block tile M
const BN: u32 = 64u; // thread-block tile N
const BK: u32 = 8u;  // K-dimension tile step
const TM: u32 = 4u;  // thread micro-tile M (each thread computes 4 rows)
const TN: u32 = 4u;  // thread micro-tile N (each thread computes 4 cols)
// Workgroup: 16×16 = 256 threads
// Each thread: 4×4 = 16 output elements
// Total: 256 threads × 16 = 4096 elements = 64×64 ✓

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
    alpha: f32, // scaling factor (default 1.0)
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

// Double-buffered shared memory tiles (four separate arrays, not one
// combined buffer): buffer 0 = smem_a0 + smem_b0, buffer 1 = smem_a1 + smem_b1.
// Total: 2 * (64*8 + 8*64) * 4 bytes = 8192 bytes = 8 KB
var<workgroup> smem_a0: array<f32, 512>; // BM * BK = 64 * 8
var<workgroup> smem_b0: array<f32, 512>; // BK * BN = 8 * 64
var<workgroup> smem_a1: array<f32, 512>; // double buffer
var<workgroup> smem_b1: array<f32, 512>; // double buffer

@compute @workgroup_size(16, 16)
fn main(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    // Thread position within workgroup (16×16 grid)
    let tx = lid.x; // [0..15]
    let ty = lid.y; // [0..15]
    let tid = ty * 16u + tx; // flat thread index [0..255]

    // This workgroup computes output tile C[bm..bm+64, bn..bn+64]
    let bm = wg_id.y * BM; // block row offset
    let bn = wg_id.x * BN; // block col offset

    // Each thread computes a 4×4 micro-tile within the 64×64 block.
    // Thread (tx, ty) computes rows [ty*4..ty*4+3], cols [tx*4..tx*4+3]
    let thread_row = ty * TM; // [0, 4, 8, ..., 60]
    let thread_col = tx * TN; // [0, 4, 8, ..., 60]

    // Accumulator registers: 4×4 = 16 per thread
    var acc: array<f32, 16>;
    for (var i = 0u; i < 16u; i++) {
        acc[i] = 0.0;
    }

    let num_k_tiles = (dims.K + BK - 1u) / BK;

    // === PROLOGUE: Load first tile into buffer 0 ===
    // Each thread loads 2 elements of A and 2 elements of B (256 threads × 2 = 512)
    let load_a_row = tid / BK; // which row of the 64×8 tile
    let load_a_col = tid % BK; // which col of the 64×8 tile
    let load_b_row = tid / BN; // which row of the 8×64 tile
    let load_b_col = tid % BN; // which col of the 8×64 tile

    // Load A[bm + load_a_row, 0 + load_a_col] into smem_a0
    let ga_row = bm + load_a_row;
    if (ga_row < dims.M && load_a_col < dims.K) {
        smem_a0[load_a_row * BK + load_a_col] = a[ga_row * dims.K + load_a_col];
    } else {
        smem_a0[load_a_row * BK + load_a_col] = 0.0;
    }
    // Second element (tid + 256 maps to rows 32..63 of the 64-row tile)
    let load_a_row2 = load_a_row + 32u;
    let ga_row2 = bm + load_a_row2;
    if (load_a_row2 < BM && ga_row2 < dims.M && load_a_col < dims.K) {
        smem_a0[load_a_row2 * BK + load_a_col] = a[ga_row2 * dims.K + load_a_col];
    } else if (load_a_row2 < BM) {
        smem_a0[load_a_row2 * BK + load_a_col] = 0.0;
    }

    // Load B[0 + load_b_row, bn + load_b_col] into smem_b0
    let gb_col = bn + load_b_col;
    if (load_b_row < dims.K && gb_col < dims.N) {
        smem_b0[load_b_row * BN + load_b_col] = b[load_b_row * dims.N + gb_col];
    } else {
        smem_b0[load_b_row * BN + load_b_col] = 0.0;
    }
    // B tile is only 8 rows × 64 cols = 512 elements = exactly 256 threads × 2
    let load_b_row2 = load_b_row + 4u;
    if (load_b_row2 < BK && load_b_row2 < dims.K && gb_col < dims.N) {
        smem_b0[load_b_row2 * BN + load_b_col] = b[load_b_row2 * dims.N + gb_col];
    } else if (load_b_row2 < BK) {
        smem_b0[load_b_row2 * BN + load_b_col] = 0.0;
    }

    workgroupBarrier();

    // === MAINLOOP: iterate over K-dimension tiles ===
    for (var kt = 0u; kt < num_k_tiles; kt++) {
        // Determine which buffer to read from (ping-pong)
        let read_buf = kt % 2u;

        // --- Compute 4×4 micro-tile from current shared memory ---
        for (var k = 0u; k < BK; k++) {
            // Load 4 A values from shared memory (one column of the micro-tile)
            var a_frag: array<f32, 4>;
            var b_frag: array<f32, 4>;

            for (var mi = 0u; mi < TM; mi++) {
                if (read_buf == 0u) {
                    a_frag[mi] = smem_a0[(thread_row + mi) * BK + k];
                } else {
                    a_frag[mi] = smem_a1[(thread_row + mi) * BK + k];
                }
            }
            for (var ni = 0u; ni < TN; ni++) {
                if (read_buf == 0u) {
                    b_frag[ni] = smem_b0[k * BN + thread_col + ni];
                } else {
                    b_frag[ni] = smem_b1[k * BN + thread_col + ni];
                }
            }

            // 4×4 outer product: acc[mi][ni] += a_frag[mi] * b_frag[ni]
            for (var mi = 0u; mi < TM; mi++) {
                for (var ni = 0u; ni < TN; ni++) {
                    acc[mi * TN + ni] += a_frag[mi] * b_frag[ni];
                }
            }
        }

        // --- Load NEXT tile into the other buffer (double buffering) ---
        // Safe without an extra barrier: write_buf != read_buf, and the
        // barrier at the end of the previous iteration ensured all threads
        // finished reading write_buf's previous contents.
        let next_k = (kt + 1u) * BK;
        let write_buf = (kt + 1u) % 2u;

        if (kt + 1u < num_k_tiles) {
            // Load A next tile
            let na_col = next_k + load_a_col;
            let na_val = select(0.0, a[ga_row * dims.K + na_col],
                                ga_row < dims.M && na_col < dims.K);
            if (write_buf == 0u) { smem_a0[load_a_row * BK + load_a_col] = na_val; }
            else { smem_a1[load_a_row * BK + load_a_col] = na_val; }

            let na_val2 = select(0.0, a[ga_row2 * dims.K + na_col],
                                 load_a_row2 < BM && ga_row2 < dims.M && na_col < dims.K);
            if (load_a_row2 < BM) {
                if (write_buf == 0u) { smem_a0[load_a_row2 * BK + load_a_col] = na_val2; }
                else { smem_a1[load_a_row2 * BK + load_a_col] = na_val2; }
            }

            // Load B next tile
            let nb_row = next_k + load_b_row;
            let nb_val = select(0.0, b[nb_row * dims.N + gb_col],
                                nb_row < dims.K && gb_col < dims.N);
            if (write_buf == 0u) { smem_b0[load_b_row * BN + load_b_col] = nb_val; }
            else { smem_b1[load_b_row * BN + load_b_col] = nb_val; }

            let nb_row2 = next_k + load_b_row2;
            if (load_b_row2 < BK) {
                let nb_val2 = select(0.0, b[nb_row2 * dims.N + gb_col],
                                     nb_row2 < dims.K && gb_col < dims.N);
                if (write_buf == 0u) { smem_b0[load_b_row2 * BN + load_b_col] = nb_val2; }
                else { smem_b1[load_b_row2 * BN + load_b_col] = nb_val2; }
            }
        }

        workgroupBarrier();
    }

    // === EPILOGUE: Write 4×4 micro-tile to global memory ===
    let alpha = dims.alpha;
    for (var mi = 0u; mi < TM; mi++) {
        for (var ni = 0u; ni < TN; ni++) {
            let grow = bm + thread_row + mi;
            let gcol = bn + thread_col + ni;
            if (grow < dims.M && gcol < dims.N) {
                c[grow * dims.N + gcol] = alpha * acc[mi * TN + ni];
            }
        }
    }
}
"#;
308
/// WGSL compute shader: fused LoRA delta, output += scale * (input @ A @ B).
///
/// One thread per output element `[row, col]`. Each thread recomputes the
/// full rank-vector h = input[row] @ A for its column — O(rank·in_dim) work
/// per element, simple but redundant across the out_dim columns of a row.
///
/// Bindings: 0 = input `[seq, in_dim]`, 1 = A `[in_dim, rank]`,
/// 2 = B `[rank, out_dim]`, 3 = output `[seq, out_dim]` (accumulated into,
/// not overwritten), 4 = uniform `LoraParams`.
///
/// NOTE(review): the `gid.y * 65535u * 256u` folding assumes the host fills
/// the x dispatch dimension to exactly 65535 workgroups before spilling into
/// y — confirm against the dispatch code.
pub const LORA_ADDMM_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;    // [seq, in_dim]
@group(0) @binding(1) var<storage, read> lora_a: array<f32>;   // [in_dim, rank]
@group(0) @binding(2) var<storage, read> lora_b: array<f32>;   // [rank, out_dim]
@group(0) @binding(3) var<storage, read_write> output: array<f32>; // [seq, out_dim] — ADD to existing

struct LoraParams {
    seq_len: u32,
    in_dim: u32,
    rank: u32,
    out_dim: u32,
    scale: f32, // alpha / rank
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(4) var<uniform> params: LoraParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.out_dim;
    if (idx >= total) { return; }

    let row = idx / params.out_dim;
    let col = idx % params.out_dim;

    // Compute (input[row] @ A) @ B[col] * scale
    // First: h = input[row] @ A → [rank] vector
    // Then: delta = h @ B[:, col] * scale → scalar
    var delta: f32 = 0.0;
    for (var r = 0u; r < params.rank; r++) {
        // h[r] = sum_k input[row, k] * A[k, r]
        var h_r: f32 = 0.0;
        for (var k = 0u; k < params.in_dim; k++) {
            h_r += input[row * params.in_dim + k] * lora_a[k * params.rank + r];
        }
        // delta += h[r] * B[r, col]
        delta += h_r * lora_b[r * params.out_dim + col];
    }

    output[row * params.out_dim + col] += delta * params.scale;
}
"#;
362
/// WGSL compute shader: scatter a `[seq_len, chunk_n]` source matrix into
/// columns `[col_offset, col_offset + chunk_n)` of a `[seq_len, full_n]`
/// destination. One thread per source element; rows outside `seq_len *
/// chunk_n` exit early. Inverse of COLUMN_GATHER_SHADER.
///
/// NOTE(review): no bounds check on `dst` — assumes the host guarantees
/// `col_offset + chunk_n <= full_n`.
pub const COLUMN_SCATTER_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

struct ScatterParams {
    seq_len: u32,
    chunk_n: u32,    // width of source
    full_n: u32,     // width of destination
    col_offset: u32, // column offset in destination
}

@group(0) @binding(2) var<uniform> params: ScatterParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.chunk_n;
    if (idx >= total) { return; }

    let row = idx / params.chunk_n;
    let col = idx % params.chunk_n;

    let src_idx = row * params.chunk_n + col;
    let dst_idx = row * params.full_n + params.col_offset + col;

    dst[dst_idx] = src[src_idx];
}
"#;
397
/// WGSL compute shader: gather columns `[col_offset, col_offset + chunk_n)`
/// of a `[seq_len, full_n]` source into a dense `[seq_len, chunk_n]`
/// destination. One thread per destination element. Inverse of
/// COLUMN_SCATTER_SHADER.
///
/// NOTE(review): no bounds check on `src` — assumes the host guarantees
/// `col_offset + chunk_n <= full_n`.
pub const COLUMN_GATHER_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

struct GatherParams {
    seq_len: u32,
    chunk_n: u32,    // width of destination
    full_n: u32,     // width of source
    col_offset: u32, // column offset in source
}

@group(0) @binding(2) var<uniform> params: GatherParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.chunk_n;
    if (idx >= total) { return; }

    let row = idx / params.chunk_n;
    let col = idx % params.chunk_n;

    let src_idx = row * params.full_n + params.col_offset + col;
    let dst_idx = row * params.chunk_n + col;

    dst[dst_idx] = src[src_idx];
}
"#;
429
/// WGSL compute shader: scaled transpose, dst = scale * src^T.
///
/// src is `[m, n]` row-major; dst is `[n, m]` row-major. One thread per
/// element, indexed by source position; writes to dst are strided (column
/// order) — no shared-memory tiling is used here.
pub const TRANSPOSE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

struct TransposeParams {
    m: u32,     // rows of source
    n: u32,     // cols of source
    scale: f32, // output scaling (1.0 for identity)
    _pad: u32,
}

@group(0) @binding(2) var<uniform> params: TransposeParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.m * params.n;
    if (idx >= total) { return; }

    let i = idx / params.n; // source row
    let j = idx % params.n; // source col

    // src[i, j] = src[i * N + j] → dst[j, i] = dst[j * M + i]
    dst[j * params.m + i] = params.scale * src[i * params.n + j];
}
"#;
461
/// WGSL compute shader: matrix-vector product y = W·x with vec4 loads.
///
/// One workgroup (256 threads) per output row: each thread accumulates a
/// strided subset of vec4 dot products, then a barrier-separated tree
/// reduction in `sdata` collapses 256 partials to one value written by
/// thread 0. Barriers below width 32 are required in WGSL — there is no
/// implicit subgroup synchronization to rely on.
///
/// NOTE(review): `k4 = params.k / 4u` truncates, so any trailing `k % 4`
/// elements are silently ignored — assumes the host pads K to a multiple
/// of 4 (confirm against buffer setup).
///
/// The early `return` on `row >= params.n` is workgroup-uniform (it depends
/// only on `workgroup_id`), so the barriers that follow remain in uniform
/// control flow.
pub(crate) const GEMV_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> x: array<vec4<f32>>; // input [K/4]
@group(0) @binding(1) var<storage, read> w: array<vec4<f32>>; // weight [N, K/4]
@group(0) @binding(2) var<storage, read_write> y: array<f32>; // output [N]

struct Params {
    n: u32, // output dim (number of rows)
    k: u32, // input dim (K, NOT K/4 — shader divides internally)
    _pad1: u32,
    _pad2: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let row = wg_id.x;
    let tid = lid.x;
    let k4 = params.k / 4u; // Number of vec4 elements per row

    if (row >= params.n) { return; }

    // Phase 1: vec4 dot product — 4 FMAs per iteration
    var partial_sum: f32 = 0.0;
    let row_offset = row * k4;
    var col4 = tid;
    while (col4 < k4) {
        let wv = w[row_offset + col4];
        let xv = x[col4];
        partial_sum += dot(wv, xv); // vec4 dot = 4 FMAs
        col4 += 256u;
    }
    sdata[tid] = partial_sum;
    workgroupBarrier();

    // Phase 2: Tree reduction (256 → 1)
    if (tid < 128u) { sdata[tid] += sdata[tid + 128u]; }
    workgroupBarrier();
    if (tid < 64u) { sdata[tid] += sdata[tid + 64u]; }
    workgroupBarrier();
    if (tid < 32u) { sdata[tid] += sdata[tid + 32u]; }
    workgroupBarrier();
    if (tid < 16u) { sdata[tid] += sdata[tid + 16u]; }
    workgroupBarrier();
    if (tid < 8u) { sdata[tid] += sdata[tid + 8u]; }
    workgroupBarrier();
    if (tid < 4u) { sdata[tid] += sdata[tid + 4u]; }
    workgroupBarrier();
    if (tid < 2u) { sdata[tid] += sdata[tid + 2u]; }
    workgroupBarrier();
    if (tid == 0u) {
        y[row] = sdata[0] + sdata[1];
    }
}
"#;
534
/// WGSL compute shader: dequantizing GEMV, y = W·x for Q4_K-quantized W.
///
/// Q4_K super-block layout (144 bytes, matching llama.cpp `block_q4_K`):
/// bytes 0-1 `d` (f16 scale), 2-3 `dmin` (f16 min scale), 4-15 twelve bytes
/// of packed 6-bit sub-block scales/mins, 16-143 one hundred twenty-eight
/// bytes of 4-bit quants (256 values, low nibbles then high nibbles per
/// 32-byte chunk). Dequantized value: `d * scale[j] * q - dmin * min[j]`.
///
/// One workgroup (256 threads) per output row; threads stride over
/// super-blocks then tree-reduce in `sdata` (same scheme as GEMV_SHADER).
/// The early `return` on `row >= params.n` is workgroup-uniform, keeping
/// the barriers in uniform control flow.
pub(crate) const Q4K_GEMV_SHADER: &str = r#"
// Q4K weights stored as array<u32> (144 bytes = 36 u32s per super-block)
@group(0) @binding(0) var<storage, read> x: array<f32>;      // input [K]
@group(0) @binding(1) var<storage, read> w_q4k: array<u32>;  // Q4K weight bytes as u32
@group(0) @binding(2) var<storage, read_write> y: array<f32>; // output [N]

struct Q4kParams {
    n: u32,               // output dim (number of rows)
    k: u32,               // input dim (number of columns)
    num_superblocks: u32, // super-blocks per row = ceil(K / 256)
    _pad: u32,
}
@group(0) @binding(3) var<uniform> params: Q4kParams;

var<workgroup> sdata: array<f32, 256>;

// Extract a u8 from a u32 array (byte-level access)
fn read_u8(base: u32, byte_offset: u32) -> u32 {
    let word_idx = base + byte_offset / 4u;
    let byte_pos = byte_offset % 4u;
    return (w_q4k[word_idx] >> (byte_pos * 8u)) & 0xFFu;
}

// Convert f16 (stored as u16 in two bytes) to f32
// PMAT-497 FIX: Use bitwise IEEE 754 construction (matching CPU f16_to_f32).
// Previous version used pow(2.0, exp) which introduced rounding errors that
// corrupted every Q4K scale factor, causing loss > random from step 1.
fn f16_to_f32(low: u32, high: u32) -> f32 {
    let bits = low | (high << 8u);
    let sign = (bits >> 15u) & 1u;
    let exp = (bits >> 10u) & 0x1Fu;
    let mantissa = bits & 0x3FFu;

    // Sign bit in f32 position
    var f32_bits = sign << 31u;

    if (exp == 0u) {
        if (mantissa == 0u) {
            // Signed zero
            return bitcast<f32>(f32_bits);
        }
        // Subnormal f16: normalize mantissa to find implicit leading 1
        var m = mantissa;
        var e = 0i;
        while ((m & 0x400u) == 0u) {
            m = m << 1u;
            e -= 1i;
        }
        // Remove implicit leading 1 and construct f32 bits
        let new_exp = u32(127 - 15 + 1 + e) << 23u;
        let new_man = (m & 0x3FFu) << 13u;
        f32_bits = f32_bits | new_exp | new_man;
        return bitcast<f32>(f32_bits);
    }
    if (exp == 31u) {
        // Inf/NaN: exponent all-ones in f32
        f32_bits = f32_bits | (0xFFu << 23u) | (mantissa << 13u);
        return bitcast<f32>(f32_bits);
    }
    // Normal f16: re-bias exponent from f16 (bias=15) to f32 (bias=127)
    let new_exp = (exp - 15u + 127u) << 23u;
    let new_man = mantissa << 13u;
    f32_bits = f32_bits | new_exp | new_man;
    return bitcast<f32>(f32_bits);
}

@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let row = wg_id.x;
    let tid = lid.x;

    if (row >= params.n) { return; }

    // Each super-block is 36 u32s (144 bytes). Row data starts at:
    let row_base_u32 = row * params.num_superblocks * 36u;

    var partial_sum: f32 = 0.0;

    // Each thread processes a subset of super-blocks for this row
    var sb_idx = tid;
    while (sb_idx < params.num_superblocks) {
        let sb_base = row_base_u32 + sb_idx * 36u;
        let input_offset = sb_idx * 256u;

        // Read d and dmin (f16 → f32)
        let byte0 = read_u8(sb_base, 0u);
        let byte1 = read_u8(sb_base, 1u);
        let byte2 = read_u8(sb_base, 2u);
        let byte3 = read_u8(sb_base, 3u);
        let d = f16_to_f32(byte0, byte1);
        let dmin = f16_to_f32(byte2, byte3);

        // Unpack 8 scales and 8 mins from bytes[4:16]
        // Layout matches llama.cpp get_scale_min_k4: sub-blocks 0-3 use the
        // low 6 bits of bytes 4-7 (scales) and 8-11 (mins); sub-blocks 4-7
        // combine nibbles of bytes 12-15 with the high 2 bits of the earlier bytes.
        var scales: array<f32, 8>;
        var mins: array<f32, 8>;

        let s0 = read_u8(sb_base, 4u);
        let s1 = read_u8(sb_base, 5u);
        let s2 = read_u8(sb_base, 6u);
        let s3 = read_u8(sb_base, 7u);
        let m0 = read_u8(sb_base, 8u);
        let m1 = read_u8(sb_base, 9u);
        let m2 = read_u8(sb_base, 10u);
        let m3 = read_u8(sb_base, 11u);
        let h0 = read_u8(sb_base, 12u);
        let h1 = read_u8(sb_base, 13u);
        let h2 = read_u8(sb_base, 14u);
        let h3 = read_u8(sb_base, 15u);

        scales[0] = f32(s0 & 0x3Fu);
        scales[1] = f32(s1 & 0x3Fu);
        scales[2] = f32(s2 & 0x3Fu);
        scales[3] = f32(s3 & 0x3Fu);
        scales[4] = f32((h0 & 0x0Fu) | ((s0 >> 6u) << 4u));
        scales[5] = f32((h1 & 0x0Fu) | ((s1 >> 6u) << 4u));
        scales[6] = f32((h2 & 0x0Fu) | ((s2 >> 6u) << 4u));
        scales[7] = f32((h3 & 0x0Fu) | ((s3 >> 6u) << 4u));

        mins[0] = f32(m0 & 0x3Fu);
        mins[1] = f32(m1 & 0x3Fu);
        mins[2] = f32(m2 & 0x3Fu);
        mins[3] = f32(m3 & 0x3Fu);
        mins[4] = f32((h0 >> 4u) | ((m0 >> 6u) << 4u));
        mins[5] = f32((h1 >> 4u) | ((m1 >> 6u) << 4u));
        mins[6] = f32((h2 >> 4u) | ((m2 >> 6u) << 4u));
        mins[7] = f32((h3 >> 4u) | ((m3 >> 6u) << 4u));

        // Process 4 chunks × 64 elements (32 low nibbles + 32 high nibbles)
        for (var chunk = 0u; chunk < 4u; chunk++) {
            let d1 = d * scales[chunk * 2u];
            let dm1 = dmin * mins[chunk * 2u];
            let d2 = d * scales[chunk * 2u + 1u];
            let dm2 = dmin * mins[chunk * 2u + 1u];

            let q_byte_start = 16u + chunk * 32u; // offset into super-block
            let elem_base = input_offset + chunk * 64u;

            // Low nibbles: 32 elements
            for (var i = 0u; i < 32u; i++) {
                let idx = elem_base + i;
                if (idx < params.k) {
                    let q_byte = read_u8(sb_base, q_byte_start + i);
                    let q_val = f32(q_byte & 0x0Fu);
                    partial_sum += (d1 * q_val - dm1) * x[idx];
                }
            }
            // High nibbles: 32 elements
            for (var i = 0u; i < 32u; i++) {
                let idx = elem_base + 32u + i;
                if (idx < params.k) {
                    let q_byte = read_u8(sb_base, q_byte_start + i);
                    let q_val = f32(q_byte >> 4u);
                    partial_sum += (d2 * q_val - dm2) * x[idx];
                }
            }
        }

        sb_idx += 256u; // stride by workgroup size
    }

    // Tree reduction (same as GEMV_SHADER)
    sdata[tid] = partial_sum;
    workgroupBarrier();

    if (tid < 128u) { sdata[tid] += sdata[tid + 128u]; }
    workgroupBarrier();
    if (tid < 64u) { sdata[tid] += sdata[tid + 64u]; }
    workgroupBarrier();
    if (tid < 32u) { sdata[tid] += sdata[tid + 32u]; }
    workgroupBarrier();
    if (tid < 16u) { sdata[tid] += sdata[tid + 16u]; }
    workgroupBarrier();
    if (tid < 8u) { sdata[tid] += sdata[tid + 8u]; }
    workgroupBarrier();
    if (tid < 4u) { sdata[tid] += sdata[tid + 4u]; }
    workgroupBarrier();
    if (tid < 2u) { sdata[tid] += sdata[tid + 2u]; }
    workgroupBarrier();
    if (tid == 0u) {
        y[row] = sdata[0] + sdata[1];
    }
}
"#;
739
/// WGSL compute shader: elementwise addition, c[i] = a[i] + b[i].
///
/// Length comes from `arrayLength(&a)`; one thread per element with a
/// 1-D dispatch (no y-dimension folding, unlike the scatter/gather kernels).
/// Assumes `b` and `c` are at least as long as `a` — only `a` is length-checked.
pub(crate) const VEC_ADD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&a);

    if (idx < len) {
        c[idx] = a[idx] + b[idx];
    }
}
"#;
758
/// WGSL compute shader: elementwise (Hadamard) product, c[i] = a[i] * b[i].
///
/// Same dispatch/bounds conventions as VEC_ADD_SHADER: length from
/// `arrayLength(&a)`, one thread per element, 1-D dispatch.
pub(crate) const VEC_MUL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&a);

    if (idx < len) {
        c[idx] = a[idx] * b[idx];
    }
}
"#;
777
/// WGSL compute shader: elementwise subtraction, c[i] = a[i] - b[i].
///
/// Same dispatch/bounds conventions as VEC_ADD_SHADER: length from
/// `arrayLength(&a)`, one thread per element, 1-D dispatch.
pub(crate) const VEC_SUB_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&a);

    if (idx < len) {
        c[idx] = a[idx] - b[idx];
    }
}
"#;
796
/// WGSL compute shader: scalar multiply, output[i] = input[i] * scalar.
///
/// The scalar comes from a uniform (`ScaleParams.scalar`); length from
/// `arrayLength(&input)`; one thread per element, 1-D dispatch.
pub(crate) const SCALE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct ScaleParams {
    scalar: f32,
}

@group(0) @binding(2) var<uniform> params: ScaleParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        output[idx] = input[idx] * params.scalar;
    }
}
"#;
820
/// WGSL compute shader: partial dot product of `a` and `b`.
///
/// Each 256-thread workgroup reduces its slice of the elementwise products
/// and writes ONE partial sum to `result[workgroup_index]`
/// (`global_id.x / 256u` equals the workgroup id for workgroup_size 256).
/// The final sum across workgroups is NOT computed here — presumably the
/// host (or a follow-up pass) reduces `result`; confirm against the caller.
///
/// The reduction loop halves `stride` from 128 with a barrier each
/// iteration; the loop condition depends only on the uniform `stride`, so
/// the barrier stays in uniform control flow.
pub(crate) const DOT_PRODUCT_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;

var<workgroup> partial_sums: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let len = arrayLength(&a);

    // Load and multiply
    var sum: f32 = 0.0;
    if (idx < len) {
        sum = a[idx] * b[idx];
    }
    partial_sums[local_idx] = sum;

    workgroupBarrier();

    // Parallel reduction within workgroup
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            partial_sums[local_idx] = partial_sums[local_idx] + partial_sums[local_idx + stride];
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        result[global_id.x / 256u] = partial_sums[0];
    }
}
"#;
865
/// WGSL compute shader: elementwise ReLU, output[i] = max(0, input[i]).
///
/// Length from `arrayLength(&input)`; one thread per element, 1-D dispatch.
pub(crate) const RELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // ReLU: max(0, x)
        output[idx] = max(0.0, input[idx]);
    }
}
"#;
887
/// WGSL compute shader: elementwise Leaky ReLU.
///
/// output[i] = input[i] if positive, else `negative_slope * input[i]`,
/// with the slope supplied via uniform `LeakyReluParams`. One thread per
/// element, 1-D dispatch, length from `arrayLength(&input)`.
pub(crate) const LEAKY_RELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct LeakyReluParams {
    negative_slope: f32,
}

@group(0) @binding(2) var<uniform> params: LeakyReluParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Leaky ReLU: leaky_relu(x, α) = x if x > 0, else αx
        if (x > 0.0) {
            output[idx] = x;
        } else {
            output[idx] = params.negative_slope * x;
        }
    }
}
"#;
921
/// WGSL compute shader: elementwise ELU.
///
/// output[i] = input[i] if positive, else `alpha * (exp(input[i]) - 1)`,
/// with `alpha` supplied via uniform `EluParams`. One thread per element,
/// 1-D dispatch, length from `arrayLength(&input)`.
pub(crate) const ELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct EluParams {
    alpha: f32,
}

@group(0) @binding(2) var<uniform> params: EluParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // ELU: elu(x, α) = x if x > 0, else α(e^x - 1)
        if (x > 0.0) {
            output[idx] = x;
        } else {
            output[idx] = params.alpha * (exp(x) - 1.0);
        }
    }
}
"#;
956
/// WGSL compute shader: elementwise numerically-stable sigmoid.
///
/// Branches on the sign of x so `exp` is only ever called with a
/// non-positive argument, avoiding overflow for large |x|:
/// x ≥ 0 → 1/(1+e^-x); x < 0 → e^x/(1+e^x). One thread per element,
/// 1-D dispatch, length from `arrayLength(&input)`.
pub(crate) const SIGMOID_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Sigmoid: σ(x) = 1 / (1 + exp(-x))
        // Numerically stable implementation:
        // For x >= 0: σ(x) = 1 / (1 + exp(-x))
        // For x < 0:  σ(x) = exp(x) / (1 + exp(x))
        var result: f32;
        if (x >= 0.0) {
            result = 1.0 / (1.0 + exp(-x));
        } else {
            let exp_x = exp(x);
            result = exp_x / (1.0 + exp_x);
        }

        output[idx] = result;
    }
}
"#;
991
/// WGSL compute shader: elementwise hyperbolic tangent.
///
/// Hand-rolled (e^2x - 1)/(e^2x + 1) with saturation to ±1 for |x| > 20,
/// where the formula is already ±1 within f32 precision. Implemented
/// explicitly rather than via the WGSL `tanh` builtin — presumably for
/// consistent precision across drivers; confirm before replacing with the
/// builtin. One thread per element, 1-D dispatch, length from
/// `arrayLength(&input)`.
pub(crate) const TANH_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Tanh: tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
        //              = (e^(2x) - 1) / (e^(2x) + 1)
        // Numerically stable implementation:
        // For |x| > 20: tanh(x) ≈ sign(x) (saturates at ±1)
        // Otherwise: use standard formula
        var result: f32;
        if (x > 20.0) {
            result = 1.0;
        } else if (x < -20.0) {
            result = -1.0;
        } else {
            let exp_2x = exp(2.0 * x);
            result = (exp_2x - 1.0) / (exp_2x + 1.0);
        }

        output[idx] = result;
    }
}
"#;
1029
/// WGSL compute shader: elementwise Swish (SiLU), x * sigmoid(x).
///
/// Uses the same sign-split as SIGMOID_SHADER so `exp` only receives
/// non-positive arguments (no overflow for large |x|). One thread per
/// element, 1-D dispatch, length from `arrayLength(&input)`.
pub(crate) const SWISH_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Swish: swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
        // Numerically stable implementation:
        // For x >= 0: swish(x) = x / (1 + exp(-x))
        // For x < 0:  swish(x) = x * exp(x) / (1 + exp(x))
        var result: f32;
        if (x >= 0.0) {
            result = x / (1.0 + exp(-x));
        } else {
            let exp_x = exp(x);
            result = x * exp_x / (1.0 + exp_x);
        }

        output[idx] = result;
    }
}
"#;
1064
/// WGSL compute shader: elementwise GELU, tanh approximation.
///
/// GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))) — the standard
/// approximation from Hendrycks & Gimpel, using the WGSL `tanh` builtin.
/// One thread per element, 1-D dispatch, length from `arrayLength(&input)`.
pub(crate) const GELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // GELU approximation (tanh-based):
        // GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
        let SQRT_2_OVER_PI: f32 = 0.7978846; // √(2/π)
        let COEFF: f32 = 0.044715;

        let x_cubed = x * x * x;
        let inner = SQRT_2_OVER_PI * (x + COEFF * x_cubed);
        let result = 0.5 * x * (1.0 + tanh(inner));

        output[idx] = result;
    }
}
"#;
1097
/// WGSL compute shader: elementwise clamp to `[min_val, max_val]`.
///
/// Bounds come from uniform `ClipParams`; delegates to the WGSL `clamp`
/// builtin. One thread per element, 1-D dispatch, length from
/// `arrayLength(&input)`.
pub(crate) const CLIP_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct ClipParams {
    min_val: f32,
    max_val: f32,
}

@group(0) @binding(2) var<uniform> params: ClipParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // Clip: clamp(x, min_val, max_val) = max(min_val, min(max_val, x))
        output[idx] = clamp(input[idx], params.min_val, params.max_val);
    }
}
"#;
1126
/// WGSL compute shader: 2-D "valid" convolution (no padding).
///
/// Note the kernel is NOT flipped (`in = out + k`), so mathematically this
/// is cross-correlation — the common convention in ML frameworks.
/// One thread per output pixel, 16×16 workgroups, with `global_id.x` as the
/// output row and `global_id.y` as the output column.
///
/// NOTE(review): input reads are unguarded; in-bounds access relies on the
/// host setting `output_rows = input_rows - kernel_rows + 1` (and likewise
/// for columns) — confirm against the dispatch code.
pub(crate) const CONVOLVE2D_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> kernel: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct ConvDimensions {
    input_rows: u32,
    input_cols: u32,
    kernel_rows: u32,
    kernel_cols: u32,
    output_rows: u32,
    output_cols: u32,
}

@group(0) @binding(3) var<uniform> dims: ConvDimensions;

// Workgroup size: 16×16 = 256 threads
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let out_row = global_id.x;
    let out_col = global_id.y;

    // Bounds check
    if (out_row >= dims.output_rows || out_col >= dims.output_cols) {
        return;
    }

    var sum: f32 = 0.0;

    // Apply kernel: iterate over kernel dimensions
    for (var k_row: u32 = 0u; k_row < dims.kernel_rows; k_row = k_row + 1u) {
        for (var k_col: u32 = 0u; k_col < dims.kernel_cols; k_col = k_col + 1u) {
            // Input pixel coordinates
            let in_row = out_row + k_row;
            let in_col = out_col + k_col;

            // Input and kernel are row-major
            let input_idx = in_row * dims.input_cols + in_col;
            let kernel_idx = k_row * dims.kernel_cols + k_col;

            sum = sum + input[input_idx] * kernel[kernel_idx];
        }
    }

    // Write output (row-major)
    let output_idx = out_row * dims.output_cols + out_col;
    output[output_idx] = sum;
}
"#;