trueno/backends/gpu/shaders/basic_ops.rs

//! Basic element-wise operations: matmul, add, mul, sub, scale,
//! dot product, activations, clip, and 2D convolution.

/// Matrix multiplication compute shader (WGSL) — tiled shared memory
///
/// Computes C = A × B where:
/// - A is M×K
/// - B is K×N
/// - C is M×N
///
/// Uses 16×16 shared memory tiles to reduce global memory traffic by ~16×.
/// Each workgroup loads tiles of A and B into `var<workgroup>` memory, then
/// computes partial products from shared memory. This is the standard tiled
/// matmul from GPU computing textbooks (KAIZEN-021).
///
/// # Contract (C-TILED-MATMUL-001)
///
/// - **Binding layout**: identical to the naive shader (0=a, 1=b, 2=c, 3=dims)
/// - **Workgroup size**: 16×16 = 256 threads (unchanged)
/// - **Dispatch**: ceil(M/16) × ceil(N/16) workgroups (unchanged)
/// - **Result**: bit-identical to the naive shader for all M, K, N (the
///   accumulation order along K is unchanged, so no f32 reassociation occurs)
/// - **Speedup**: 5–15× on real GPUs (bandwidth-bound → compute-bound)
pub const MATMUL_SHADER: &str = r#"
const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

struct Dimensions {
    M: u32,  // rows of A and C
    K: u32,  // cols of A, rows of B
    N: u32,  // cols of B and C
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

// Shared memory tiles — each 16×16 = 256 floats
var<workgroup> tile_a: array<f32, 256>;
var<workgroup> tile_b: array<f32, 256>;

// Workgroup size: 16×16 = 256 threads
@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let row = global_id.x;
    let col = global_id.y;
    let lr = local_id.x;  // local row within tile [0..15]
    let lc = local_id.y;  // local col within tile [0..15]

    var sum: f32 = 0.0;

    // Iterate over K dimension in tiles of 16
    let num_tiles = (dims.K + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load A tile: A[row, t*TILE + lc]
        let a_col = t * TILE + lc;
        if (row < dims.M && a_col < dims.K) {
            tile_a[lr * TILE + lc] = a[row * dims.K + a_col];
        } else {
            tile_a[lr * TILE + lc] = 0.0;
        }

        // Load B tile: B[t*TILE + lr, col]
        let b_row = t * TILE + lr;
        if (b_row < dims.K && col < dims.N) {
            tile_b[lr * TILE + lc] = b[b_row * dims.N + col];
        } else {
            tile_b[lr * TILE + lc] = 0.0;
        }

        // Wait for all threads to finish loading
        workgroupBarrier();

        // Accumulate partial dot product from shared memory
        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            sum = sum + tile_a[lr * TILE + k] * tile_b[k * TILE + lc];
        }

        // Wait before loading next tile (prevents overwriting while others read)
        workgroupBarrier();
    }

    // Write result
    if (row < dims.M && col < dims.N) {
        c[row * dims.N + col] = sum;
    }
}
"#;

/// CUTLASS-style tiled GEMM compute shader (WGSL) — 64×64 output tiles
///
/// Computes C = α·(A × B) where A is M×K, B is K×N, C is M×N. (There is no
/// β·C term: the output buffer is overwritten, not accumulated into.)
///
/// ## CUTLASS-derived tiling (MIT licensed algorithm)
///
/// - **Thread-block tile**: 64×64 output, K-step: 8
/// - **Thread micro-tile**: 4×4 output elements per thread
/// - **Workgroup**: 16×16 = 256 threads
/// - **Shared memory**: double-buffered (2 × 64×8 × 4 bytes × 2 matrices = 8 KB)
/// - **Inner loop**: 4×4 outer product from shared memory per K-step
/// - **Global loads**: scalar f32 (each thread loads two elements of A and two of B per tile)
///
/// ## Performance vs naive 16×16
///
/// Each thread computes 16 output elements (4×4) instead of 1, amortizing
/// shared memory loads by 16×. Double buffering overlaps the next tile load
/// with the current tile's compute. Expected 10–30× speedup over MATMUL_SHADER.
///
/// ## Contract (wgsl-gemm-tiled-v1)
///
/// - Binding layout: 0=a, 1=b, 2=c, 3=dims (compatible with MATMUL_SHADER)
/// - Dispatch: ceil(M/64) × ceil(N/64) workgroups
/// - Result: matches naive within ε < 1e-4 (f32 reassociation)
/// - Zero unsafe: entirely via the safe wgpu Rust API
pub const TILED_GEMM_SHADER: &str = r#"
// CUTLASS-derived tiled GEMM — 64×64 tiles, 4×4 thread micro-tiles
// Algorithm from NVIDIA CUTLASS (MIT licensed), reimplemented in WGSL.

const BM: u32 = 64u;       // thread-block tile M
const BN: u32 = 64u;       // thread-block tile N
const BK: u32 = 8u;        // K-dimension tile step
const TM: u32 = 4u;        // thread micro-tile M (each thread computes 4 rows)
const TN: u32 = 4u;        // thread micro-tile N (each thread computes 4 cols)
// Workgroup: 16×16 = 256 threads
// Each thread: 4×4 = 16 output elements
// Total: 256 threads × 16 = 4096 elements = 64×64 ✓

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
    alpha: f32,   // scaling factor (default 1.0)
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

// Double-buffered shared memory tiles (ping-pong between buffers 0 and 1):
// each buffer holds a 64×8 tile of A and an 8×64 tile of B.
// Total: 2 * (64*8 + 8*64) * 4 = 2 * 1024 * 4 = 8192 bytes = 8 KB
var<workgroup> smem_a0: array<f32, 512>;  // BM * BK = 64 * 8
var<workgroup> smem_b0: array<f32, 512>;  // BK * BN = 8 * 64
var<workgroup> smem_a1: array<f32, 512>;  // double buffer
var<workgroup> smem_b1: array<f32, 512>;  // double buffer

@compute @workgroup_size(16, 16)
fn main(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    // Thread position within workgroup (16×16 grid)
    let tx = lid.x;  // [0..15]
    let ty = lid.y;  // [0..15]
    let tid = ty * 16u + tx;  // flat thread index [0..255]

    // This workgroup computes output tile C[bm..bm+64, bn..bn+64]
    let bm = wg_id.y * BM;  // block row offset
    let bn = wg_id.x * BN;  // block col offset

    // Each thread computes a 4×4 micro-tile within the 64×64 block.
    // Thread (tx, ty) computes rows [ty*4..ty*4+3], cols [tx*4..tx*4+3]
    let thread_row = ty * TM;  // [0, 4, 8, ..., 60]
    let thread_col = tx * TN;  // [0, 4, 8, ..., 60]

    // Accumulator registers: 4×4 = 16 per thread
    var acc: array<f32, 16>;
    for (var i = 0u; i < 16u; i++) {
        acc[i] = 0.0;
    }

    let num_k_tiles = (dims.K + BK - 1u) / BK;

    // === PROLOGUE: Load first tile into buffer 0 ===
    // Each thread loads 2 elements of A and 2 elements of B (256 threads × 2 = 512)
    let load_a_row = tid / BK;       // which row of the 64×8 tile
    let load_a_col = tid % BK;       // which col of the 64×8 tile
    let load_b_row = tid / BN;       // which row of the 8×64 tile
    let load_b_col = tid % BN;       // which col of the 8×64 tile

    // Load A[bm + load_a_row, 0 + load_a_col] into smem_a0
    let ga_row = bm + load_a_row;
    if (ga_row < dims.M && load_a_col < dims.K) {
        smem_a0[load_a_row * BK + load_a_col] = a[ga_row * dims.K + load_a_col];
    } else {
        smem_a0[load_a_row * BK + load_a_col] = 0.0;
    }
    // Second element (tid + 256 maps to rows 32..63 of the 64-row tile)
    let load_a_row2 = load_a_row + 32u;
    let ga_row2 = bm + load_a_row2;
    if (load_a_row2 < BM && ga_row2 < dims.M && load_a_col < dims.K) {
        smem_a0[load_a_row2 * BK + load_a_col] = a[ga_row2 * dims.K + load_a_col];
    } else if (load_a_row2 < BM) {
        smem_a0[load_a_row2 * BK + load_a_col] = 0.0;
    }

    // Load B[0 + load_b_row, bn + load_b_col] into smem_b0
    let gb_col = bn + load_b_col;
    if (load_b_row < dims.K && gb_col < dims.N) {
        smem_b0[load_b_row * BN + load_b_col] = b[load_b_row * dims.N + gb_col];
    } else {
        smem_b0[load_b_row * BN + load_b_col] = 0.0;
    }
    // B tile is only 8 rows × 64 cols = 512 elements = exactly 256 threads × 2
    let load_b_row2 = load_b_row + 4u;
    if (load_b_row2 < BK && load_b_row2 < dims.K && gb_col < dims.N) {
        smem_b0[load_b_row2 * BN + load_b_col] = b[load_b_row2 * dims.N + gb_col];
    } else if (load_b_row2 < BK) {
        smem_b0[load_b_row2 * BN + load_b_col] = 0.0;
    }

    workgroupBarrier();

    // === MAINLOOP: iterate over K-dimension tiles ===
    for (var kt = 0u; kt < num_k_tiles; kt++) {
        // K offset of the tile currently resident in the read buffer
        let k_offset = kt * BK;

        // Determine which buffer to read from (ping-pong)
        let read_buf = kt % 2u;

        // --- Compute 4×4 micro-tile from current shared memory ---
        for (var k = 0u; k < BK; k++) {
            // Load 4 A values from shared memory (one column of the micro-tile)
            var a_frag: array<f32, 4>;
            var b_frag: array<f32, 4>;

            for (var mi = 0u; mi < TM; mi++) {
                if (read_buf == 0u) {
                    a_frag[mi] = smem_a0[(thread_row + mi) * BK + k];
                } else {
                    a_frag[mi] = smem_a1[(thread_row + mi) * BK + k];
                }
            }
            for (var ni = 0u; ni < TN; ni++) {
                if (read_buf == 0u) {
                    b_frag[ni] = smem_b0[k * BN + thread_col + ni];
                } else {
                    b_frag[ni] = smem_b1[k * BN + thread_col + ni];
                }
            }

            // 4×4 outer product: acc[mi][ni] += a_frag[mi] * b_frag[ni]
            for (var mi = 0u; mi < TM; mi++) {
                for (var ni = 0u; ni < TN; ni++) {
                    acc[mi * TN + ni] += a_frag[mi] * b_frag[ni];
                }
            }
        }

        // --- Load NEXT tile into the other buffer (double buffering) ---
        let next_k = (kt + 1u) * BK;
        let write_buf = (kt + 1u) % 2u;

        if (kt + 1u < num_k_tiles) {
            // Load A next tile
            let na_col = next_k + load_a_col;
            let na_val = select(0.0, a[ga_row * dims.K + na_col],
                ga_row < dims.M && na_col < dims.K);
            if (write_buf == 0u) { smem_a0[load_a_row * BK + load_a_col] = na_val; }
            else { smem_a1[load_a_row * BK + load_a_col] = na_val; }

            let na_val2 = select(0.0, a[ga_row2 * dims.K + na_col],
                load_a_row2 < BM && ga_row2 < dims.M && na_col < dims.K);
            if (load_a_row2 < BM) {
                if (write_buf == 0u) { smem_a0[load_a_row2 * BK + load_a_col] = na_val2; }
                else { smem_a1[load_a_row2 * BK + load_a_col] = na_val2; }
            }

            // Load B next tile
            let nb_row = next_k + load_b_row;
            let nb_val = select(0.0, b[nb_row * dims.N + gb_col],
                nb_row < dims.K && gb_col < dims.N);
            if (write_buf == 0u) { smem_b0[load_b_row * BN + load_b_col] = nb_val; }
            else { smem_b1[load_b_row * BN + load_b_col] = nb_val; }

            let nb_row2 = next_k + load_b_row2;
            if (load_b_row2 < BK) {
                let nb_val2 = select(0.0, b[nb_row2 * dims.N + gb_col],
                    nb_row2 < dims.K && gb_col < dims.N);
                if (write_buf == 0u) { smem_b0[load_b_row2 * BN + load_b_col] = nb_val2; }
                else { smem_b1[load_b_row2 * BN + load_b_col] = nb_val2; }
            }
        }

        workgroupBarrier();
    }

    // === EPILOGUE: Write 4×4 micro-tile to global memory ===
    let alpha = dims.alpha;
    for (var mi = 0u; mi < TM; mi++) {
        for (var ni = 0u; ni < TN; ni++) {
            let grow = bm + thread_row + mi;
            let gcol = bn + thread_col + ni;
            if (grow < dims.M && gcol < dims.N) {
                c[grow * dims.N + gcol] = alpha * acc[mi * TN + ni];
            }
        }
    }
}
"#;

/// Fused LoRA addmm: output += (input @ A) @ B * scale
///
/// Computes the LoRA contribution and adds it to the base projection output.
/// Two matmuls plus a scaled add, fused into one kernel; the rank-sized
/// intermediate is accumulated in per-thread registers, never materialized.
/// For rank << hidden_dim, this work is much smaller than the base matmul.
///
/// Dispatch: ceil(seq × out_dim / 256) workgroups (output is [seq, out_dim];
/// split into a 2D grid when the count exceeds 65535).
/// Each thread computes one output element's LoRA delta.
pub const LORA_ADDMM_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;   // [seq, in_dim]
@group(0) @binding(1) var<storage, read> lora_a: array<f32>;  // [in_dim, rank]
@group(0) @binding(2) var<storage, read> lora_b: array<f32>;  // [rank, out_dim]
@group(0) @binding(3) var<storage, read_write> output: array<f32>; // [seq, out_dim] — ADD to existing

struct LoraParams {
    seq_len: u32,
    in_dim: u32,
    rank: u32,
    out_dim: u32,
    scale: f32,    // alpha / rank
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(4) var<uniform> params: LoraParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.out_dim;
    if (idx >= total) { return; }

    let row = idx / params.out_dim;
    let col = idx % params.out_dim;

    // Compute (input[row] @ A) @ B[col] * scale
    // First: h = input[row] @ A → [rank] vector
    // Then: delta = h @ B[:, col] * scale → scalar
    var delta: f32 = 0.0;
    for (var r = 0u; r < params.rank; r++) {
        // h[r] = sum_k input[row, k] * A[k, r]
        var h_r: f32 = 0.0;
        for (var k = 0u; k < params.in_dim; k++) {
            h_r += input[row * params.in_dim + k] * lora_a[k * params.rank + r];
        }
        // delta += h[r] * B[r, col]
        delta += h_r * lora_b[r * params.out_dim + col];
    }

    output[row * params.out_dim + col] += delta * params.scale;
}
"#;

/// Column scatter shader — copies chunk columns into a wider row-major matrix.
///
/// Replaces N × copy_buffer_to_buffer calls with a single GPU dispatch.
/// Source: [seq, chunk_n] row-major → Dest: [seq, full_n] at column offset.
///
/// Each thread copies one element.
pub const COLUMN_SCATTER_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

struct ScatterParams {
    seq_len: u32,
    chunk_n: u32,    // width of source
    full_n: u32,     // width of destination
    col_offset: u32, // column offset in destination
}

@group(0) @binding(2) var<uniform> params: ScatterParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.chunk_n;
    if (idx >= total) { return; }

    let row = idx / params.chunk_n;
    let col = idx % params.chunk_n;

    let src_idx = row * params.chunk_n + col;
    let dst_idx = row * params.full_n + params.col_offset + col;

    dst[dst_idx] = src[src_idx];
}
"#;

/// Column gather shader — extracts columns from a wide matrix into a chunk.
///
/// Inverse of scatter. Used for backward: extract grad_logits columns per chunk.
pub const COLUMN_GATHER_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

struct GatherParams {
    seq_len: u32,
    chunk_n: u32,    // width of destination
    full_n: u32,     // width of source
    col_offset: u32, // column offset in source
}

@group(0) @binding(2) var<uniform> params: GatherParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.chunk_n;
    if (idx >= total) { return; }

    let row = idx / params.chunk_n;
    let col = idx % params.chunk_n;

    let src_idx = row * params.full_n + params.col_offset + col;
    let dst_idx = row * params.chunk_n + col;

    dst[dst_idx] = src[src_idx];
}
"#;

/// Scaled transpose: B[j,i] = scale * A[i,j]
/// Contract: wgsl-transpose-v1
///
/// Dispatch: ceil(M*N / 256) workgroups (split into a 2D grid above 65535).
/// Params: { M, N, scale, _pad }
pub const TRANSPOSE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

struct TransposeParams {
    m: u32,      // rows of source
    n: u32,      // cols of source
    scale: f32,  // output scaling (1.0 for identity)
    _pad: u32,
}

@group(0) @binding(2) var<uniform> params: TransposeParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.m * params.n;
    if (idx >= total) { return; }

    let i = idx / params.n;  // source row
    let j = idx % params.n;  // source col

    // src[i, j] lives at i * N + j; dst[j, i] lives at j * M + i
    dst[j * params.m + i] = params.scale * src[i * params.n + j];
}
"#;

/// PMAT-326: GEMV compute shader (WGSL) — matrix-vector product y = W × x
///
/// Optimized for M=1 (single-token decode). Each workgroup computes ONE output
/// element by cooperatively reducing the dot product along K using shared memory.
///
/// - W: [N, K] row-major weight matrix
/// - x: [K] input vector
/// - y: [N] output vector
///
/// Workgroup: 256 threads. Each workgroup handles 1 output row.
/// Dispatch: N workgroups (one per output element).
/// Reduction: tree reduction in shared memory (log2(256) = 8 steps).
/// PMAT-331: vec4 vectorized GEMV — 4× fewer memory transactions.
/// Each thread loads vec4<f32> (4 floats per load), dot4 in registers.
/// K must be divisible by 4 (true for all Qwen dimensions: 1536, 256, 8960).
pub(crate) const GEMV_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> x: array<vec4<f32>>;     // input [K/4]
@group(0) @binding(1) var<storage, read> w: array<vec4<f32>>;     // weight [N, K/4]
@group(0) @binding(2) var<storage, read_write> y: array<f32>;     // output [N]

struct Params {
    n: u32,  // output dim (number of rows)
    k: u32,  // input dim (K, NOT K/4 — shader divides internally)
    _pad1: u32,
    _pad2: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let row = wg_id.x;
    let tid = lid.x;
    let k4 = params.k / 4u;  // Number of vec4 elements per row

    if (row >= params.n) { return; }

    // Phase 1: vec4 dot product — 4 FMAs per iteration
    var partial_sum: f32 = 0.0;
    let row_offset = row * k4;
    var col4 = tid;
    while (col4 < k4) {
        let wv = w[row_offset + col4];
        let xv = x[col4];
        partial_sum += dot(wv, xv);  // vec4 dot = 4 FMAs
        col4 += 256u;
    }
    sdata[tid] = partial_sum;
    workgroupBarrier();

    // Phase 2: Tree reduction (256 → 1)
    if (tid < 128u) { sdata[tid] += sdata[tid + 128u]; }
    workgroupBarrier();
    if (tid < 64u) { sdata[tid] += sdata[tid + 64u]; }
    workgroupBarrier();
    if (tid < 32u) { sdata[tid] += sdata[tid + 32u]; }
    workgroupBarrier();
    if (tid < 16u) { sdata[tid] += sdata[tid + 16u]; }
    workgroupBarrier();
    if (tid < 8u) { sdata[tid] += sdata[tid + 8u]; }
    workgroupBarrier();
    if (tid < 4u) { sdata[tid] += sdata[tid + 4u]; }
    workgroupBarrier();
    if (tid < 2u) { sdata[tid] += sdata[tid + 2u]; }
    workgroupBarrier();
    if (tid == 0u) {
        y[row] = sdata[0] + sdata[1];
    }
}
"#;

/// Q4_K quantized matrix-vector product (WGSL) — C-WGPU-Q4K-001
///
/// Computes y[row] = Σ_col dequant(W_q4k[row, col]) × x[col]
/// where W is stored as raw Q4_K super-blocks (144 bytes → 256 f32 values).
///
/// Dequantization happens on-the-fly per-thread — no F32 weight buffer.
/// This reduces VRAM from 4×num_params (F32) to 144/256×num_params (Q4K),
/// a ≈7.1× reduction.
///
/// Q4_K super-block layout (144 bytes per 256 elements):
///   bytes[0:2]    = d    (f16, global scale)
///   bytes[2:4]    = dmin (f16, global min scale)
///   bytes[4:16]   = 12 packed scale/min bytes (8 sub-blocks, 6-bit packed)
///   bytes[16:144] = 128 quantized nibble bytes (4-bit, interleaved low/high)
///
/// Each sub-block (32 elements): value = d × scale × nibble - dmin × min
///
/// Workgroup: 256 threads per output row.
/// Dispatch: N workgroups (one per output element).
/// Each thread processes ceil(num_superblocks/256) super-blocks, accumulating
/// 256 elements per super-block into a partial sum, then tree-reduces.
pub(crate) const Q4K_GEMV_SHADER: &str = r#"
// Q4K weights stored as array<u32> (144 bytes = 36 u32s per super-block)
@group(0) @binding(0) var<storage, read> x: array<f32>;       // input [K]
@group(0) @binding(1) var<storage, read> w_q4k: array<u32>;   // Q4K weight bytes as u32
@group(0) @binding(2) var<storage, read_write> y: array<f32>;  // output [N]

struct Q4kParams {
    n: u32,               // output dim (number of rows)
    k: u32,               // input dim (number of columns)
    num_superblocks: u32, // super-blocks per row = ceil(K / 256)
    _pad: u32,
}
@group(0) @binding(3) var<uniform> params: Q4kParams;

var<workgroup> sdata: array<f32, 256>;

// Extract a u8 from a u32 array (byte-level access)
fn read_u8(base: u32, byte_offset: u32) -> u32 {
    let word_idx = base + byte_offset / 4u;
    let byte_pos = byte_offset % 4u;
    return (w_q4k[word_idx] >> (byte_pos * 8u)) & 0xFFu;
}

// Convert f16 (stored as u16 in two bytes) to f32
// PMAT-497 FIX: Use bitwise IEEE 754 construction (matching CPU f16_to_f32).
// The previous version used pow(2.0, exp), whose rounding errors corrupted
// every Q4K scale factor and made training loss worse than random from step 1.
fn f16_to_f32(low: u32, high: u32) -> f32 {
    let bits = low | (high << 8u);
    let sign = (bits >> 15u) & 1u;
    let exp = (bits >> 10u) & 0x1Fu;
    let mantissa = bits & 0x3FFu;

    // Sign bit in f32 position
    var f32_bits = sign << 31u;

    if (exp == 0u) {
        if (mantissa == 0u) {
            // Signed zero
            return bitcast<f32>(f32_bits);
        }
        // Subnormal f16: normalize mantissa to find implicit leading 1
        var m = mantissa;
        var e = 0i;
        while ((m & 0x400u) == 0u) {
            m = m << 1u;
            e -= 1i;
        }
        // Remove implicit leading 1 and construct f32 bits
        let new_exp = u32(127 - 15 + 1 + e) << 23u;
        let new_man = (m & 0x3FFu) << 13u;
        f32_bits = f32_bits | new_exp | new_man;
        return bitcast<f32>(f32_bits);
    }
    if (exp == 31u) {
        // Inf/NaN: exponent all-ones in f32
        f32_bits = f32_bits | (0xFFu << 23u) | (mantissa << 13u);
        return bitcast<f32>(f32_bits);
    }
    // Normal f16: re-bias exponent from f16 (bias=15) to f32 (bias=127)
    let new_exp = (exp - 15u + 127u) << 23u;
    let new_man = mantissa << 13u;
    f32_bits = f32_bits | new_exp | new_man;
    return bitcast<f32>(f32_bits);
}

@compute @workgroup_size(256)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let row = wg_id.x;
    let tid = lid.x;

    if (row >= params.n) { return; }

    // Each super-block is 36 u32s (144 bytes). Row data starts at:
    let row_base_u32 = row * params.num_superblocks * 36u;

    var partial_sum: f32 = 0.0;

    // Each thread processes a subset of super-blocks for this row
    var sb_idx = tid;
    while (sb_idx < params.num_superblocks) {
        let sb_base = row_base_u32 + sb_idx * 36u;
        let input_offset = sb_idx * 256u;

        // Read d and dmin (f16 → f32)
        let byte0 = read_u8(sb_base, 0u);
        let byte1 = read_u8(sb_base, 1u);
        let byte2 = read_u8(sb_base, 2u);
        let byte3 = read_u8(sb_base, 3u);
        let d = f16_to_f32(byte0, byte1);
        let dmin = f16_to_f32(byte2, byte3);

        // Unpack 8 scales and 8 mins from bytes[4:16]
        var scales: array<f32, 8>;
        var mins: array<f32, 8>;

        let s0 = read_u8(sb_base, 4u);
        let s1 = read_u8(sb_base, 5u);
        let s2 = read_u8(sb_base, 6u);
        let s3 = read_u8(sb_base, 7u);
        let m0 = read_u8(sb_base, 8u);
        let m1 = read_u8(sb_base, 9u);
        let m2 = read_u8(sb_base, 10u);
        let m3 = read_u8(sb_base, 11u);
        let h0 = read_u8(sb_base, 12u);
        let h1 = read_u8(sb_base, 13u);
        let h2 = read_u8(sb_base, 14u);
        let h3 = read_u8(sb_base, 15u);

        scales[0] = f32(s0 & 0x3Fu);
        scales[1] = f32(s1 & 0x3Fu);
        scales[2] = f32(s2 & 0x3Fu);
        scales[3] = f32(s3 & 0x3Fu);
        scales[4] = f32((h0 & 0x0Fu) | ((s0 >> 6u) << 4u));
        scales[5] = f32((h1 & 0x0Fu) | ((s1 >> 6u) << 4u));
        scales[6] = f32((h2 & 0x0Fu) | ((s2 >> 6u) << 4u));
        scales[7] = f32((h3 & 0x0Fu) | ((s3 >> 6u) << 4u));

        mins[0] = f32(m0 & 0x3Fu);
        mins[1] = f32(m1 & 0x3Fu);
        mins[2] = f32(m2 & 0x3Fu);
        mins[3] = f32(m3 & 0x3Fu);
        mins[4] = f32((h0 >> 4u) | ((m0 >> 6u) << 4u));
        mins[5] = f32((h1 >> 4u) | ((m1 >> 6u) << 4u));
        mins[6] = f32((h2 >> 4u) | ((m2 >> 6u) << 4u));
        mins[7] = f32((h3 >> 4u) | ((m3 >> 6u) << 4u));

        // Process 4 chunks × 64 elements (32 low nibbles + 32 high nibbles)
        for (var chunk = 0u; chunk < 4u; chunk++) {
            let d1 = d * scales[chunk * 2u];
            let dm1 = dmin * mins[chunk * 2u];
            let d2 = d * scales[chunk * 2u + 1u];
            let dm2 = dmin * mins[chunk * 2u + 1u];

            let q_byte_start = 16u + chunk * 32u;  // offset into super-block
            let elem_base = input_offset + chunk * 64u;

            // Low nibbles: 32 elements
            for (var i = 0u; i < 32u; i++) {
                let idx = elem_base + i;
                if (idx < params.k) {
                    let q_byte = read_u8(sb_base, q_byte_start + i);
                    let q_val = f32(q_byte & 0x0Fu);
                    partial_sum += (d1 * q_val - dm1) * x[idx];
                }
            }
            // High nibbles: 32 elements
            for (var i = 0u; i < 32u; i++) {
                let idx = elem_base + 32u + i;
                if (idx < params.k) {
                    let q_byte = read_u8(sb_base, q_byte_start + i);
                    let q_val = f32(q_byte >> 4u);
                    partial_sum += (d2 * q_val - dm2) * x[idx];
                }
            }
        }

        sb_idx += 256u;  // stride by workgroup size
    }

    // Tree reduction (same as GEMV_SHADER)
    sdata[tid] = partial_sum;
    workgroupBarrier();

    if (tid < 128u) { sdata[tid] += sdata[tid + 128u]; }
    workgroupBarrier();
    if (tid < 64u) { sdata[tid] += sdata[tid + 64u]; }
    workgroupBarrier();
    if (tid < 32u) { sdata[tid] += sdata[tid + 32u]; }
    workgroupBarrier();
    if (tid < 16u) { sdata[tid] += sdata[tid + 16u]; }
    workgroupBarrier();
    if (tid < 8u) { sdata[tid] += sdata[tid + 8u]; }
    workgroupBarrier();
    if (tid < 4u) { sdata[tid] += sdata[tid + 4u]; }
    workgroupBarrier();
    if (tid < 2u) { sdata[tid] += sdata[tid + 2u]; }
    workgroupBarrier();
    if (tid == 0u) {
        y[row] = sdata[0] + sdata[1];
    }
}
"#;

/// Vector addition compute shader (WGSL)
///
/// Computes c = a + b element-wise
pub(crate) const VEC_ADD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&a);

    if (idx < len) {
        c[idx] = a[idx] + b[idx];
    }
}
"#;

/// Element-wise multiplication shader (WGSL)
///
/// Computes c = a * b element-wise
pub(crate) const VEC_MUL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&a);

    if (idx < len) {
        c[idx] = a[idx] * b[idx];
    }
}
"#;

/// Element-wise subtraction shader (WGSL)
///
/// Computes c = a - b element-wise
pub(crate) const VEC_SUB_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&a);

    if (idx < len) {
        c[idx] = a[idx] - b[idx];
    }
}
"#;

/// Scalar multiplication shader (WGSL)
///
/// Computes output = input * scalar element-wise
pub(crate) const SCALE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct ScaleParams {
    scalar: f32,
}

@group(0) @binding(2) var<uniform> params: ScaleParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        output[idx] = input[idx] * params.scalar;
    }
}
"#;

/// Dot product reduction shader (WGSL)
///
/// Computes sum(a[i] * b[i]) by parallel reduction. Each 256-thread
/// workgroup writes ONE partial sum to result[workgroup_index]; a second
/// pass (host-side or another dispatch) must sum the partials to produce
/// the final scalar.
pub(crate) const DOT_PRODUCT_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;

var<workgroup> partial_sums: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let len = arrayLength(&a);

    // Load and multiply
    var sum: f32 = 0.0;
    if (idx < len) {
        sum = a[idx] * b[idx];
    }
    partial_sums[local_idx] = sum;

    workgroupBarrier();

    // Parallel reduction within workgroup
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            partial_sums[local_idx] = partial_sums[local_idx] + partial_sums[local_idx + stride];
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        result[global_id.x / 256u] = partial_sums[0];
    }
}
"#;

/// ReLU activation compute shader (WGSL)
///
/// Computes element-wise ReLU: max(0, x)
///
/// This is one of the simplest GPU operations: a single comparison and selection per element.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const RELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // ReLU: max(0, x)
        output[idx] = max(0.0, input[idx]);
    }
}
"#;

/// Leaky ReLU activation compute shader (WGSL)
///
/// Computes element-wise Leaky ReLU: leaky_relu(x, α) = x if x > 0, else αx
/// (equivalently max(αx, x) for the usual range 0 ≤ α < 1)
///
/// Leaky ReLU addresses the "dying ReLU" problem by allowing small negative activations.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const LEAKY_RELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct LeakyReluParams {
    negative_slope: f32,
}

@group(0) @binding(2) var<uniform> params: LeakyReluParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Leaky ReLU: leaky_relu(x, α) = x if x > 0, else αx
        if (x > 0.0) {
            output[idx] = x;
        } else {
            output[idx] = params.negative_slope * x;
        }
    }
}
"#;

/// ELU (Exponential Linear Unit) activation compute shader (WGSL)
///
/// Computes element-wise ELU: elu(x, α) = x if x > 0, else α(e^x - 1)
///
/// ELU has smooth gradients everywhere and pushes mean activations closer to zero,
/// improving learning in deep networks.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const ELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct EluParams {
    alpha: f32,
}

@group(0) @binding(2) var<uniform> params: EluParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // ELU: elu(x, α) = x if x > 0, else α(e^x - 1)
        if (x > 0.0) {
            output[idx] = x;
        } else {
            output[idx] = params.alpha * (exp(x) - 1.0);
        }
    }
}
"#;

/// Sigmoid activation compute shader (WGSL)
///
/// Computes element-wise sigmoid: σ(x) = 1 / (1 + e^(-x))
///
/// Classic logistic function used in binary classification and attention mechanisms.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const SIGMOID_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Sigmoid: σ(x) = 1 / (1 + exp(-x))
        // Numerically stable implementation:
        // For x >= 0: σ(x) = 1 / (1 + exp(-x))
        // For x < 0: σ(x) = exp(x) / (1 + exp(x))
        var result: f32;
        if (x >= 0.0) {
            result = 1.0 / (1.0 + exp(-x));
        } else {
            let exp_x = exp(x);
            result = exp_x / (1.0 + exp_x);
        }

        output[idx] = result;
    }
}
"#;

/// Tanh (hyperbolic tangent) activation compute shader (WGSL)
///
/// Computes element-wise tanh: tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
///
/// Classic activation function used in LSTM, GRU, and traditional neural networks.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const TANH_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Tanh: tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
        //                = (e^(2x) - 1) / (e^(2x) + 1)
        // Numerically stable implementation:
        // For |x| > 20: tanh(x) ≈ sign(x) (saturates at ±1)
        // Otherwise: use standard formula
        var result: f32;
        if (x > 20.0) {
            result = 1.0;
        } else if (x < -20.0) {
            result = -1.0;
        } else {
            let exp_2x = exp(2.0 * x);
            result = (exp_2x - 1.0) / (exp_2x + 1.0);
        }

        output[idx] = result;
    }
}
"#;

/// Swish activation compute shader (WGSL)
///
/// Computes element-wise swish: swish(x) = x * σ(x) = x / (1 + e^(-x))
///
/// Modern activation function (SiLU) used in transformers and modern architectures.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const SWISH_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // Swish: swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
        // Numerically stable implementation:
        // For x >= 0: swish(x) = x / (1 + exp(-x))
        // For x < 0: swish(x) = x * exp(x) / (1 + exp(x))
        var result: f32;
        if (x >= 0.0) {
            result = x / (1.0 + exp(-x));
        } else {
            let exp_x = exp(x);
            result = x * exp_x / (1.0 + exp_x);
        }

        output[idx] = result;
    }
}
"#;

/// GELU activation compute shader (WGSL)
///
/// Computes element-wise GELU using tanh approximation:
/// GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
///
/// Standard activation in BERT, GPT-2, GPT-3, and modern transformers.
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const GELU_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        let x = input[idx];

        // GELU approximation (tanh-based):
        // GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
        let SQRT_2_OVER_PI: f32 = 0.7978846; // √(2/π)
        let COEFF: f32 = 0.044715;

        let x_cubed = x * x * x;
        let inner = SQRT_2_OVER_PI * (x + COEFF * x_cubed);
        let result = 0.5 * x * (1.0 + tanh(inner));

        output[idx] = result;
    }
}
"#;

/// Clip (clamp) compute shader (WGSL)
///
/// Computes element-wise clip: clamp(x, min_val, max_val)
///
/// Constrains values to the range [min_val, max_val].
/// GPU acceleration beneficial for large vectors (>100K elements).
pub(crate) const CLIP_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct ClipParams {
    min_val: f32,
    max_val: f32,
}

@group(0) @binding(2) var<uniform> params: ClipParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    let len = arrayLength(&input);

    if (idx < len) {
        // Clip: clamp(x, min_val, max_val) = max(min_val, min(max_val, x))
        output[idx] = clamp(input[idx], params.min_val, params.max_val);
    }
}
"#;

/// 2D Convolution compute shader (WGSL)
///
/// Computes 2D convolution: output = input ⊗ kernel
/// (implemented as cross-correlation: the kernel is not flipped, per the usual ML convention)
/// Uses "valid" padding (no padding, output smaller than input)
///
/// Output dimensions:
/// - output_rows = input_rows - kernel_rows + 1
/// - output_cols = input_cols - kernel_cols + 1
///
/// Uses 16×16 thread workgroups (256 threads) for good GPU occupancy
pub(crate) const CONVOLVE2D_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> kernel: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct ConvDimensions {
    input_rows: u32,
    input_cols: u32,
    kernel_rows: u32,
    kernel_cols: u32,
    output_rows: u32,
    output_cols: u32,
}

@group(0) @binding(3) var<uniform> dims: ConvDimensions;

// Workgroup size: 16×16 = 256 threads
@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let out_row = global_id.x;
    let out_col = global_id.y;

    // Bounds check
    if (out_row >= dims.output_rows || out_col >= dims.output_cols) {
        return;
    }

    var sum: f32 = 0.0;

    // Apply kernel: iterate over kernel dimensions
    for (var k_row: u32 = 0u; k_row < dims.kernel_rows; k_row = k_row + 1u) {
        for (var k_col: u32 = 0u; k_col < dims.kernel_cols; k_col = k_col + 1u) {
            // Input pixel coordinates
            let in_row = out_row + k_row;
            let in_col = out_col + k_col;

            // Input and kernel are row-major
            let input_idx = in_row * dims.input_cols + in_col;
            let kernel_idx = k_row * dims.kernel_cols + k_col;

            sum = sum + input[input_idx] * kernel[kernel_idx];
        }
    }

    // Write output (row-major)
    let output_idx = out_row * dims.output_cols + out_col;
    output[output_idx] = sum;
}
"#;