//! Advanced WGSL shaders: Jacobi eigenvalue, tiled 2D reductions, causal attention.

3/// Causal multi-head attention (WGSL) — scaled dot-product with GQA
4///
5/// Computes: attn_out = softmax(Q @ K^T / sqrt(d) + causal_mask) @ V
6///
7/// Supports grouped-query attention (GQA): num_heads >= num_kv_heads.
8/// Each KV head serves `num_heads / num_kv_heads` query heads.
9///
10/// Grid: (num_heads, seq_len, 1) — one workgroup per (head, query position)
11/// Each workgroup computes one row of the attention output.
12///
13/// # Contract (causal-attention-v1)
14///
15/// - Causal mask: position i can only attend to positions [0..i]
16/// - Softmax sums to 1.0 per row (within floating-point tolerance)
17/// - Output shape matches Q shape: [seq_len, num_heads * head_dim]
18pub const CAUSAL_ATTENTION_SHADER: &str = r#"
19// Parallel causal attention — 128 threads cooperatively compute dot products.
20//
21// Each workgroup handles one (head, query_position) pair.
22// 128 threads parallelize the head_dim dot product (128 = head_dim for Qwen).
23// Reduction via shared memory to produce scalar score per K position.
24// Then softmax + weighted V sum, also parallelized across head_dim.
25//
26// Complexity: O(seq × head_dim / 128) per workgroup = 128x faster than sequential.
27
28@group(0) @binding(0) var<storage, read> q: array<f32>;
29@group(0) @binding(1) var<storage, read> k: array<f32>;
30@group(0) @binding(2) var<storage, read> v: array<f32>;
31@group(0) @binding(3) var<storage, read_write> out: array<f32>;
32
33struct AttnParams {
34    seq_len: u32,
35    num_heads: u32,
36    num_kv_heads: u32,
37    head_dim: u32,
38}
39
40@group(0) @binding(4) var<uniform> cfg: AttnParams;
41
42// Shared memory for parallel reduction + softmax weights
43var<workgroup> reduce_buf: array<f32, 128>;  // for dot product reduction
44var<workgroup> weights: array<f32, 2048>;     // softmax weights per K position
45
46@compute @workgroup_size(128)
47fn main(
48    @builtin(global_invocation_id) gid: vec3<u32>,
49    @builtin(local_invocation_id) lid: vec3<u32>,
50) {
51    let head = gid.x / 128u;  // workgroup = head
52    let pos = gid.y;           // workgroup = position
53    let tid = lid.x;           // thread within workgroup [0..127]
54    let seq = cfg.seq_len;
55    let hd = cfg.head_dim;
56    let kv_group = cfg.num_heads / cfg.num_kv_heads;
57    let kv_head = head / kv_group;
58
59    if (head >= cfg.num_heads || pos >= seq) { return; }
60
61    let q_offset = pos * cfg.num_heads * hd + head * hd;
62    let scale = 1.0 / sqrt(f32(hd));
63
64    // Pass 1: compute QK^T scores using parallel dot product
65    var max_score: f32 = -1e30;
66
67    for (var s = 0u; s <= pos; s++) {
68        let k_offset = s * cfg.num_kv_heads * hd + kv_head * hd;
69
70        // Parallel dot product: each thread handles one element of head_dim
71        var partial: f32 = 0.0;
72        if (tid < hd) {
73            partial = q[q_offset + tid] * k[k_offset + tid];
74        }
75        reduce_buf[tid] = partial;
76        workgroupBarrier();
77
78        // Tree reduction to compute full dot product
79        if (hd >= 128u && tid < 64u) { reduce_buf[tid] += reduce_buf[tid + 64u]; }
80        workgroupBarrier();
81        if (tid < 32u) { reduce_buf[tid] += reduce_buf[tid + 32u]; }
82        workgroupBarrier();
83        if (tid < 16u) { reduce_buf[tid] += reduce_buf[tid + 16u]; }
84        workgroupBarrier();
85        if (tid < 8u) { reduce_buf[tid] += reduce_buf[tid + 8u]; }
86        workgroupBarrier();
87        if (tid < 4u) { reduce_buf[tid] += reduce_buf[tid + 4u]; }
88        workgroupBarrier();
89        if (tid < 2u) { reduce_buf[tid] += reduce_buf[tid + 2u]; }
90        workgroupBarrier();
91        if (tid < 1u) { reduce_buf[tid] += reduce_buf[tid + 1u]; }
92        workgroupBarrier();
93
94        // Thread 0 has the full dot product
95        if (tid == 0u) {
96            let score = reduce_buf[0] * scale;
97            weights[s] = score;
98            max_score = max(max_score, score);
99        }
100        workgroupBarrier();
101    }
102
103    // Broadcast max_score to all threads
104    if (tid == 0u) { reduce_buf[0] = max_score; }
105    workgroupBarrier();
106    max_score = reduce_buf[0];
107
108    // Parallel softmax: 128 threads process chunks of the seq positions
109    // Each thread handles ceil(pos/128) positions for exp + partial sum
110    let chunk_size = (pos + 128u) / 128u;
111    let s_start = tid * chunk_size;
112    let s_end = min(s_start + chunk_size, pos + 1u);
113
114    // Parallel exp + partial sum
115    var partial_sum: f32 = 0.0;
116    for (var s = s_start; s < s_end; s++) {
117        let w = exp(weights[s] - max_score);
118        weights[s] = w;
119        partial_sum += w;
120    }
121    reduce_buf[tid] = partial_sum;
122    workgroupBarrier();
123
124    // Reduce partial sums (tree reduction)
125    if (tid < 64u) { reduce_buf[tid] += reduce_buf[tid + 64u]; }
126    workgroupBarrier();
127    if (tid < 32u) { reduce_buf[tid] += reduce_buf[tid + 32u]; }
128    workgroupBarrier();
129    if (tid < 16u) { reduce_buf[tid] += reduce_buf[tid + 16u]; }
130    workgroupBarrier();
131    if (tid < 8u) { reduce_buf[tid] += reduce_buf[tid + 8u]; }
132    workgroupBarrier();
133    if (tid < 4u) { reduce_buf[tid] += reduce_buf[tid + 4u]; }
134    workgroupBarrier();
135    if (tid < 2u) { reduce_buf[tid] += reduce_buf[tid + 2u]; }
136    workgroupBarrier();
137    if (tid < 1u) { reduce_buf[tid] += reduce_buf[tid + 1u]; }
138    workgroupBarrier();
139
140    // Parallel normalize
141    let inv_sum = 1.0 / reduce_buf[0];
142    for (var s = s_start; s < s_end; s++) {
143        weights[s] = weights[s] * inv_sum;
144    }
145    workgroupBarrier();
146
147    // Pass 2: weighted V sum — each thread handles one output dimension
148    if (tid < hd) {
149        let out_offset = pos * cfg.num_heads * hd + head * hd;
150        var val: f32 = 0.0;
151        for (var s = 0u; s <= pos; s++) {
152            let v_offset = s * cfg.num_kv_heads * hd + kv_head * hd;
153            val += weights[s] * v[v_offset + tid];
154        }
155        out[out_offset + tid] = val;
156    }
157}
158"#;
159
/// Jacobi rotation shader (WGSL) - Apply Givens rotation to matrix columns
///
/// Applies rotation to columns p and q of matrix A and eigenvector matrix V:
/// - A[:,p] = c * A[:,p] - s * A[:,q]
/// - A[:,q] = s * old_A[:,p] + c * A[:,q]
///
/// Same transformation is applied to the V matrix (eigenvectors).
///
/// This is a single rotation step in the Jacobi eigenvalue algorithm.
/// Parallelizes over rows (each thread handles one row; threads with
/// row index >= n return early, so overshooting the dispatch is safe).
///
/// Bindings: 0 = matrix (read_write, n×n, row-major indexing `row * n + col`),
/// 1 = eigenvectors (read_write, same layout), 2 = uniform
/// `JacobiParams { n, p, q, c: cos(theta), s: sin(theta) }`.
pub(crate) const JACOBI_ROTATION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> matrix: array<f32>;
@group(0) @binding(1) var<storage, read_write> eigenvectors: array<f32>;

struct JacobiParams {
    n: u32,      // Matrix dimension
    p: u32,      // First column index
    q: u32,      // Second column index
    c: f32,      // cos(theta)
    s: f32,      // sin(theta)
}

@group(0) @binding(2) var<uniform> params: JacobiParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let k = global_id.x;
    let n = params.n;
    let p = params.p;
    let q = params.q;
    let c = params.c;
    let s = params.s;

    if (k >= n) {
        return;
    }

    // Update matrix row k, columns p and q
    let idx_kp = k * n + p;
    let idx_kq = k * n + q;

    let akp = matrix[idx_kp];
    let akq = matrix[idx_kq];

    matrix[idx_kp] = c * akp - s * akq;
    matrix[idx_kq] = s * akp + c * akq;

    // Update eigenvector matrix row k, columns p and q
    let vkp = eigenvectors[idx_kp];
    let vkq = eigenvectors[idx_kq];

    eigenvectors[idx_kp] = c * vkp - s * vkq;
    eigenvectors[idx_kq] = s * vkp + c * vkq;
}
"#;

/// 2D Tiled Sum Reduction compute shader (WGSL)
///
/// Computes sum reduction using 16×16 workgroups for optimal memory coalescing.
/// Phase 1: Each workgroup reduces a tile to partial sums
/// Phase 2: Combine partial sums (can be done on CPU for small number of workgroups)
///
/// This is more efficient than 1D reduction for 2D data (images, matrices)
/// as it exploits 2D spatial locality in GPU memory hierarchies.
///
/// Bindings: 0 = input (read, width×height row-major), 1 = partial_results
/// (read_write, one f32 per workgroup, indexed `wg.y * num_wg.x + wg.x`),
/// 2 = uniform `Dimensions { width, height }`. Out-of-bounds lanes contribute
/// the sum identity 0.0, so partial sizes need not divide evenly by 16.
pub(crate) const TILED_SUM_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_results: array<f32>;

struct Dimensions {
    width: u32,   // Input width (columns)
    height: u32,  // Input height (rows)
}

@group(0) @binding(2) var<uniform> dims: Dimensions;

// 16×16 workgroup shared memory tile
var<workgroup> tile: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let lx = local_id.x;
    let ly = local_id.y;
    let gx = global_id.x;
    let gy = global_id.y;

    // Load to shared memory (bounds-checked)
    var val: f32 = 0.0;
    if (gx < dims.width && gy < dims.height) {
        let idx = gy * dims.width + gx;
        val = input[idx];
    }
    tile[ly][lx] = val;

    workgroupBarrier();

    // Row reduction (horizontal): 16 -> 8 -> 4 -> 2 -> 1
    if (lx < 8u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 8u]; }
    workgroupBarrier();
    if (lx < 4u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 4u]; }
    workgroupBarrier();
    if (lx < 2u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 2u]; }
    workgroupBarrier();
    if (lx < 1u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 1u]; }
    workgroupBarrier();

    // Column reduction (vertical): first column only, 16 -> 8 -> 4 -> 2 -> 1
    if (lx == 0u) {
        if (ly < 8u) { tile[ly][0] = tile[ly][0] + tile[ly + 8u][0]; }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 4u) { tile[ly][0] = tile[ly][0] + tile[ly + 4u][0]; }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 2u) { tile[ly][0] = tile[ly][0] + tile[ly + 2u][0]; }
    }
    workgroupBarrier();
    if (lx == 0u) {
        // Final step runs on thread (0,0) itself, so no barrier is needed
        // before the same thread writes the result below.
        if (ly < 1u) { tile[ly][0] = tile[ly][0] + tile[ly + 1u][0]; }
    }

    // First thread writes workgroup result
    if (lx == 0u && ly == 0u) {
        let wg_idx = workgroup_id.y * num_workgroups.x + workgroup_id.x;
        partial_results[wg_idx] = tile[0][0];
    }
}
"#;

/// 2D Tiled Max Reduction compute shader (WGSL)
///
/// Computes max reduction using 16×16 workgroups for optimal memory coalescing.
/// Same algorithm as tiled sum reduction but with max operation.
///
/// Bindings: 0 = input (read, width×height row-major), 1 = partial_results
/// (read_write, one f32 per workgroup), 2 = uniform `Dimensions`.
/// Out-of-bounds lanes contribute -FLT_MAX, the identity for max.
pub(crate) const TILED_MAX_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_results: array<f32>;

struct Dimensions {
    width: u32,
    height: u32,
}

@group(0) @binding(2) var<uniform> dims: Dimensions;

var<workgroup> tile: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let lx = local_id.x;
    let ly = local_id.y;
    let gx = global_id.x;
    let gy = global_id.y;

    // Load to shared memory (use -inf for out-of-bounds)
    var val: f32 = -3.402823466e+38; // -FLT_MAX
    if (gx < dims.width && gy < dims.height) {
        let idx = gy * dims.width + gx;
        val = input[idx];
    }
    tile[ly][lx] = val;

    workgroupBarrier();

    // Row reduction with max
    if (lx < 8u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 8u]); }
    workgroupBarrier();
    if (lx < 4u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 4u]); }
    workgroupBarrier();
    if (lx < 2u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 2u]); }
    workgroupBarrier();
    if (lx < 1u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 1u]); }
    workgroupBarrier();

    // Column reduction with max
    if (lx == 0u) {
        if (ly < 8u) { tile[ly][0] = max(tile[ly][0], tile[ly + 8u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 4u) { tile[ly][0] = max(tile[ly][0], tile[ly + 4u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 2u) { tile[ly][0] = max(tile[ly][0], tile[ly + 2u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        // Final step executes on thread (0,0), which also writes the result.
        if (ly < 1u) { tile[ly][0] = max(tile[ly][0], tile[ly + 1u][0]); }
    }

    // First thread writes workgroup result
    if (lx == 0u && ly == 0u) {
        let wg_idx = workgroup_id.y * num_workgroups.x + workgroup_id.x;
        partial_results[wg_idx] = tile[0][0];
    }
}
"#;

/// 2D Tiled Min Reduction compute shader (WGSL)
///
/// Computes min reduction using 16×16 workgroups for optimal memory coalescing.
/// Same algorithm as tiled sum reduction but with min operation.
///
/// Bindings: 0 = input (read, width×height row-major), 1 = partial_results
/// (read_write, one f32 per workgroup), 2 = uniform `Dimensions`.
/// Out-of-bounds lanes contribute +FLT_MAX, the identity for min.
pub(crate) const TILED_MIN_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_results: array<f32>;

struct Dimensions {
    width: u32,
    height: u32,
}

@group(0) @binding(2) var<uniform> dims: Dimensions;

var<workgroup> tile: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let lx = local_id.x;
    let ly = local_id.y;
    let gx = global_id.x;
    let gy = global_id.y;

    // Load to shared memory (use +inf for out-of-bounds)
    var val: f32 = 3.402823466e+38; // +FLT_MAX
    if (gx < dims.width && gy < dims.height) {
        let idx = gy * dims.width + gx;
        val = input[idx];
    }
    tile[ly][lx] = val;

    workgroupBarrier();

    // Row reduction with min
    if (lx < 8u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 8u]); }
    workgroupBarrier();
    if (lx < 4u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 4u]); }
    workgroupBarrier();
    if (lx < 2u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 2u]); }
    workgroupBarrier();
    if (lx < 1u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 1u]); }
    workgroupBarrier();

    // Column reduction with min
    if (lx == 0u) {
        if (ly < 8u) { tile[ly][0] = min(tile[ly][0], tile[ly + 8u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 4u) { tile[ly][0] = min(tile[ly][0], tile[ly + 4u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 2u) { tile[ly][0] = min(tile[ly][0], tile[ly + 2u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        // Final step executes on thread (0,0), which also writes the result.
        if (ly < 1u) { tile[ly][0] = min(tile[ly][0], tile[ly + 1u][0]); }
    }

    // First thread writes workgroup result
    if (lx == 0u && ly == 0u) {
        let wg_idx = workgroup_id.y * num_workgroups.x + workgroup_id.x;
        partial_results[wg_idx] = tile[0][0];
    }
}
"#;

/// Find max off-diagonal element shader (WGSL) - parallel reduction
///
/// Finds the largest absolute off-diagonal element for Jacobi pivot selection.
/// Returns (max_value, row_index, col_index) packed in result buffer.
///
/// Bindings: 0 = matrix (read, n×n row-major; only the upper triangle i < j
/// is scanned), 1 = result (read_write, 3 f32 per workgroup:
/// [abs_max, f32(row), f32(col)]), 2 = uniform `MatrixParams { n }`.
/// NOTE(review): indices are returned as f32, which is exact only for
/// n < 2^24 — verify against expected matrix sizes if this is ever enabled.
///
/// Note: Currently unused - pivot selection done on CPU for simplicity.
/// Future optimization: use this shader for fully GPU-based pivot selection.
pub(crate) const _JACOBI_MAX_OFFDIAG_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> matrix: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

struct MatrixParams {
    n: u32,
}

@group(0) @binding(2) var<uniform> params: MatrixParams;

// Workgroup shared memory for reduction
var<workgroup> partial_max: array<f32, 256>;
var<workgroup> partial_row: array<u32, 256>;
var<workgroup> partial_col: array<u32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let n = params.n;

    // Total off-diagonal elements: n*(n-1)/2
    let total_pairs = n * (n - 1u) / 2u;

    // Convert linear index to (i, j) where i < j
    var max_val: f32 = 0.0;
    var max_row: u32 = 0u;
    var max_col: u32 = 1u;

    if (idx < total_pairs) {
        // Map linear index to upper triangular (i, j) where i < j
        // Using quadratic formula inversion
        var i: u32 = 0u;
        var j: u32 = 0u;
        var count: u32 = 0u;

        // Walk rows, skipping (n - 1 - row) pairs per row, until the row
        // containing linear index `idx` is found. O(n) per thread.
        for (var row: u32 = 0u; row < n - 1u; row = row + 1u) {
            let pairs_in_row = n - 1u - row;
            if (count + pairs_in_row > idx) {
                i = row;
                j = row + 1u + (idx - count);
                break;
            }
            count = count + pairs_in_row;
        }

        let aij = matrix[i * n + j];
        max_val = abs(aij);
        max_row = i;
        max_col = j;
    }

    partial_max[local_idx] = max_val;
    partial_row[local_idx] = max_row;
    partial_col[local_idx] = max_col;

    workgroupBarrier();

    // Parallel reduction to find max within workgroup.
    // Each step reads [stride..2*stride) and writes [0..stride) — disjoint
    // ranges, with a barrier between steps, so there is no data race.
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            if (partial_max[local_idx + stride] > partial_max[local_idx]) {
                partial_max[local_idx] = partial_max[local_idx + stride];
                partial_row[local_idx] = partial_row[local_idx + stride];
                partial_col[local_idx] = partial_col[local_idx + stride];
            }
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        let wg_idx = workgroup_id.x * 3u;
        result[wg_idx] = partial_max[0];
        result[wg_idx + 1u] = f32(partial_row[0]);
        result[wg_idx + 2u] = f32(partial_col[0]);
    }
}
"#;
535"#;