// trueno/backends/gpu/shaders/advanced.rs
/// WGSL compute shader: parallel causal attention.
///
/// One workgroup handles one (head, query_position) pair; its 128 threads
/// cooperate on the QK^T dot products (strided across `head_dim`, so any
/// head_dim — not only 128 — is supported), reduce through shared memory,
/// apply a numerically stable softmax (max-subtracted), then accumulate the
/// weighted V sum. K/V use grouped-query addressing via `num_kv_heads`.
///
/// Constraints:
/// * `seq_len` <= 2048 (capacity of the `weights` workgroup array).
/// * Dispatch as (num_heads, seq_len, 1) workgroups of 128 threads.
///
/// Fixes over the previous revision:
/// * `head`/`pos` now come from `workgroup_id`, which WGSL uniformity
///   analysis can prove workgroup-uniform — the early return no longer
///   invalidates the `workgroupBarrier()` calls that follow it. The values
///   are identical to the old `gid.x / 128u` / `gid.y` for this dispatch.
/// * The first tree-reduction step is no longer gated on `head_dim >= 128`,
///   which silently dropped contributions for 64 < head_dim < 128.
///   Lanes >= head_dim hold 0.0, so the unconditional add is safe.
pub const CAUSAL_ATTENTION_SHADER: &str = r#"
// Parallel causal attention — 128 threads cooperatively compute dot products.
//
// Each workgroup handles one (head, query_position) pair.
// 128 threads parallelize the head_dim dot product (strided if head_dim > 128).
// Reduction via shared memory to produce scalar score per K position.
// Then softmax + weighted V sum, also parallelized across head_dim.
//
// Complexity: O(seq × head_dim / 128) per workgroup = 128x faster than sequential.

@group(0) @binding(0) var<storage, read> q: array<f32>;
@group(0) @binding(1) var<storage, read> k: array<f32>;
@group(0) @binding(2) var<storage, read> v: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

struct AttnParams {
    seq_len: u32,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
}

@group(0) @binding(4) var<uniform> cfg: AttnParams;

// Shared memory for parallel reduction + softmax weights
var<workgroup> reduce_buf: array<f32, 128>; // for dot product reduction
var<workgroup> weights: array<f32, 2048>;   // softmax weights per K position (seq_len <= 2048)

@compute @workgroup_size(128)
fn main(
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let head = wid.x; // workgroup = head (provably uniform: safe for barriers below)
    let pos = wid.y;  // workgroup = position
    let tid = lid.x;  // thread within workgroup [0..127]
    let seq = cfg.seq_len;
    let hd = cfg.head_dim;
    let kv_group = cfg.num_heads / cfg.num_kv_heads;
    let kv_head = head / kv_group;

    // Workgroup-uniform early-out: every thread of a workgroup takes the
    // same branch, so the barriers below remain in uniform control flow.
    if (head >= cfg.num_heads || pos >= seq) { return; }

    let q_offset = pos * cfg.num_heads * hd + head * hd;
    let scale = 1.0 / sqrt(f32(hd));

    // Pass 1: compute QK^T scores using parallel dot product
    var max_score: f32 = -1e30;

    for (var s = 0u; s <= pos; s++) {
        let k_offset = s * cfg.num_kv_heads * hd + kv_head * hd;

        // Parallel dot product: each thread accumulates a strided slice
        // of head_dim (one element when hd <= 128, more when hd > 128).
        var partial: f32 = 0.0;
        for (var d = tid; d < hd; d += 128u) {
            partial += q[q_offset + d] * k[k_offset + d];
        }
        reduce_buf[tid] = partial;
        workgroupBarrier();

        // Tree reduction to compute the full dot product.
        // Lanes >= hd hold 0.0, so every step runs unconditionally.
        if (tid < 64u) { reduce_buf[tid] += reduce_buf[tid + 64u]; }
        workgroupBarrier();
        if (tid < 32u) { reduce_buf[tid] += reduce_buf[tid + 32u]; }
        workgroupBarrier();
        if (tid < 16u) { reduce_buf[tid] += reduce_buf[tid + 16u]; }
        workgroupBarrier();
        if (tid < 8u) { reduce_buf[tid] += reduce_buf[tid + 8u]; }
        workgroupBarrier();
        if (tid < 4u) { reduce_buf[tid] += reduce_buf[tid + 4u]; }
        workgroupBarrier();
        if (tid < 2u) { reduce_buf[tid] += reduce_buf[tid + 2u]; }
        workgroupBarrier();
        if (tid < 1u) { reduce_buf[tid] += reduce_buf[tid + 1u]; }
        workgroupBarrier();

        // Thread 0 has the full dot product
        if (tid == 0u) {
            let score = reduce_buf[0] * scale;
            weights[s] = score;
            max_score = max(max_score, score);
        }
        workgroupBarrier();
    }

    // Broadcast max_score (tracked by thread 0 only) to all threads
    if (tid == 0u) { reduce_buf[0] = max_score; }
    workgroupBarrier();
    max_score = reduce_buf[0];

    // Parallel softmax: 128 threads process chunks of the seq positions.
    // Each thread handles ceil((pos + 1) / 128) positions for exp + partial sum.
    let chunk_size = (pos + 128u) / 128u;
    let s_start = tid * chunk_size;
    let s_end = min(s_start + chunk_size, pos + 1u);

    // Parallel exp + partial sum (max subtracted for numerical stability)
    var partial_sum: f32 = 0.0;
    for (var s = s_start; s < s_end; s++) {
        let w = exp(weights[s] - max_score);
        weights[s] = w;
        partial_sum += w;
    }
    reduce_buf[tid] = partial_sum;
    workgroupBarrier();

    // Reduce partial sums (tree reduction)
    if (tid < 64u) { reduce_buf[tid] += reduce_buf[tid + 64u]; }
    workgroupBarrier();
    if (tid < 32u) { reduce_buf[tid] += reduce_buf[tid + 32u]; }
    workgroupBarrier();
    if (tid < 16u) { reduce_buf[tid] += reduce_buf[tid + 16u]; }
    workgroupBarrier();
    if (tid < 8u) { reduce_buf[tid] += reduce_buf[tid + 8u]; }
    workgroupBarrier();
    if (tid < 4u) { reduce_buf[tid] += reduce_buf[tid + 4u]; }
    workgroupBarrier();
    if (tid < 2u) { reduce_buf[tid] += reduce_buf[tid + 2u]; }
    workgroupBarrier();
    if (tid < 1u) { reduce_buf[tid] += reduce_buf[tid + 1u]; }
    workgroupBarrier();

    // Parallel normalize
    let inv_sum = 1.0 / reduce_buf[0];
    for (var s = s_start; s < s_end; s++) {
        weights[s] = weights[s] * inv_sum;
    }
    workgroupBarrier();

    // Pass 2: weighted V sum — threads stride across output dimensions
    let out_offset = pos * cfg.num_heads * hd + head * hd;
    for (var d = tid; d < hd; d += 128u) {
        var val: f32 = 0.0;
        for (var s = 0u; s <= pos; s++) {
            let v_offset = s * cfg.num_kv_heads * hd + kv_head * hd;
            val += weights[s] * v[v_offset + d];
        }
        out[out_offset + d] = val;
    }
}
"#;
159
/// WGSL compute shader: apply one Jacobi plane rotation on the GPU.
///
/// For every row `k` (one thread per row), replaces columns `p` and `q` of
/// both `matrix` and `eigenvectors` with their rotation by the angle given
/// as `c = cos(theta)`, `s = sin(theta)` — i.e. the column update A·J and
/// the eigenvector accumulation V·J.
///
/// NOTE(review): only the column pair of each row is rotated here. The full
/// symmetric similarity transform J^T·A·J also needs the matching row
/// update; presumably the host performs it (second dispatch or CPU-side) —
/// confirm against the dispatch code.
///
/// Dispatch: ceil(n / 256) workgroups of 256 threads.
pub(crate) const JACOBI_ROTATION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> matrix: array<f32>;
@group(0) @binding(1) var<storage, read_write> eigenvectors: array<f32>;

struct JacobiParams {
    n: u32, // Matrix dimension
    p: u32, // First column index
    q: u32, // Second column index
    c: f32, // cos(theta)
    s: f32, // sin(theta)
}

@group(0) @binding(2) var<uniform> params: JacobiParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let k = global_id.x;
    let n = params.n;
    let p = params.p;
    let q = params.q;
    let c = params.c;
    let s = params.s;

    if (k >= n) {
        return;
    }

    // Update matrix row k, columns p and q
    let idx_kp = k * n + p;
    let idx_kq = k * n + q;

    let akp = matrix[idx_kp];
    let akq = matrix[idx_kq];

    matrix[idx_kp] = c * akp - s * akq;
    matrix[idx_kq] = s * akp + c * akq;

    // Update eigenvector matrix row k, columns p and q
    let vkp = eigenvectors[idx_kp];
    let vkq = eigenvectors[idx_kq];

    eigenvectors[idx_kp] = c * vkp - s * vkq;
    eigenvectors[idx_kq] = s * vkp + c * vkq;
}
"#;
215
/// WGSL compute shader: tiled 2-D sum reduction.
///
/// Each 16×16 workgroup loads one tile of `input` into shared memory
/// (out-of-bounds lanes contribute the additive identity 0.0), tree-reduces
/// it horizontally then vertically, and writes a single partial sum per
/// workgroup to `partial_results[wg_y * num_wg_x + wg_x]`. The caller sums
/// the partial results (follow-up pass or CPU) to obtain the final total.
pub(crate) const TILED_SUM_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_results: array<f32>;

struct Dimensions {
    width: u32,  // Input width (columns)
    height: u32, // Input height (rows)
}

@group(0) @binding(2) var<uniform> dims: Dimensions;

// 16×16 workgroup shared memory tile
var<workgroup> tile: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let lx = local_id.x;
    let ly = local_id.y;
    let gx = global_id.x;
    let gy = global_id.y;

    // Load to shared memory (bounds-checked)
    var val: f32 = 0.0;
    if (gx < dims.width && gy < dims.height) {
        let idx = gy * dims.width + gx;
        val = input[idx];
    }
    tile[ly][lx] = val;

    workgroupBarrier();

    // Row reduction (horizontal): 16 -> 8 -> 4 -> 2 -> 1
    if (lx < 8u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 8u]; }
    workgroupBarrier();
    if (lx < 4u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 4u]; }
    workgroupBarrier();
    if (lx < 2u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 2u]; }
    workgroupBarrier();
    if (lx < 1u) { tile[ly][lx] = tile[ly][lx] + tile[ly][lx + 1u]; }
    workgroupBarrier();

    // Column reduction (vertical): first column only, 16 -> 8 -> 4 -> 2 -> 1
    if (lx == 0u) {
        if (ly < 8u) { tile[ly][0] = tile[ly][0] + tile[ly + 8u][0]; }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 4u) { tile[ly][0] = tile[ly][0] + tile[ly + 4u][0]; }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 2u) { tile[ly][0] = tile[ly][0] + tile[ly + 2u][0]; }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 1u) { tile[ly][0] = tile[ly][0] + tile[ly + 1u][0]; }
    }

    // First thread writes workgroup result
    // (no barrier needed: thread (0,0) performed the final add itself)
    if (lx == 0u && ly == 0u) {
        let wg_idx = workgroup_id.y * num_workgroups.x + workgroup_id.x;
        partial_results[wg_idx] = tile[0][0];
    }
}
"#;
295
/// WGSL compute shader: tiled 2-D max reduction.
///
/// Same structure as `TILED_SUM_REDUCTION_SHADER`, but reduces with `max`
/// and pads out-of-bounds lanes with -FLT_MAX (the identity for max).
/// Each 16×16 workgroup writes one partial maximum to
/// `partial_results[wg_y * num_wg_x + wg_x]`; the caller reduces those.
pub(crate) const TILED_MAX_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_results: array<f32>;

struct Dimensions {
    width: u32,
    height: u32,
}

@group(0) @binding(2) var<uniform> dims: Dimensions;

var<workgroup> tile: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let lx = local_id.x;
    let ly = local_id.y;
    let gx = global_id.x;
    let gy = global_id.y;

    // Load to shared memory (use -inf for out-of-bounds)
    var val: f32 = -3.402823466e+38; // -FLT_MAX
    if (gx < dims.width && gy < dims.height) {
        let idx = gy * dims.width + gx;
        val = input[idx];
    }
    tile[ly][lx] = val;

    workgroupBarrier();

    // Row reduction with max
    if (lx < 8u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 8u]); }
    workgroupBarrier();
    if (lx < 4u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 4u]); }
    workgroupBarrier();
    if (lx < 2u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 2u]); }
    workgroupBarrier();
    if (lx < 1u) { tile[ly][lx] = max(tile[ly][lx], tile[ly][lx + 1u]); }
    workgroupBarrier();

    // Column reduction with max
    if (lx == 0u) {
        if (ly < 8u) { tile[ly][0] = max(tile[ly][0], tile[ly + 8u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 4u) { tile[ly][0] = max(tile[ly][0], tile[ly + 4u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 2u) { tile[ly][0] = max(tile[ly][0], tile[ly + 2u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 1u) { tile[ly][0] = max(tile[ly][0], tile[ly + 1u][0]); }
    }

    // First thread writes workgroup result
    // (no barrier needed: thread (0,0) performed the final max itself)
    if (lx == 0u && ly == 0u) {
        let wg_idx = workgroup_id.y * num_workgroups.x + workgroup_id.x;
        partial_results[wg_idx] = tile[0][0];
    }
}
"#;
369
/// WGSL compute shader: tiled 2-D min reduction.
///
/// Same structure as `TILED_SUM_REDUCTION_SHADER`, but reduces with `min`
/// and pads out-of-bounds lanes with +FLT_MAX (the identity for min).
/// Each 16×16 workgroup writes one partial minimum to
/// `partial_results[wg_y * num_wg_x + wg_x]`; the caller reduces those.
pub(crate) const TILED_MIN_REDUCTION_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> partial_results: array<f32>;

struct Dimensions {
    width: u32,
    height: u32,
}

@group(0) @binding(2) var<uniform> dims: Dimensions;

var<workgroup> tile: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let lx = local_id.x;
    let ly = local_id.y;
    let gx = global_id.x;
    let gy = global_id.y;

    // Load to shared memory (use +inf for out-of-bounds)
    var val: f32 = 3.402823466e+38; // +FLT_MAX
    if (gx < dims.width && gy < dims.height) {
        let idx = gy * dims.width + gx;
        val = input[idx];
    }
    tile[ly][lx] = val;

    workgroupBarrier();

    // Row reduction with min
    if (lx < 8u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 8u]); }
    workgroupBarrier();
    if (lx < 4u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 4u]); }
    workgroupBarrier();
    if (lx < 2u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 2u]); }
    workgroupBarrier();
    if (lx < 1u) { tile[ly][lx] = min(tile[ly][lx], tile[ly][lx + 1u]); }
    workgroupBarrier();

    // Column reduction with min
    if (lx == 0u) {
        if (ly < 8u) { tile[ly][0] = min(tile[ly][0], tile[ly + 8u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 4u) { tile[ly][0] = min(tile[ly][0], tile[ly + 4u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 2u) { tile[ly][0] = min(tile[ly][0], tile[ly + 2u][0]); }
    }
    workgroupBarrier();
    if (lx == 0u) {
        if (ly < 1u) { tile[ly][0] = min(tile[ly][0], tile[ly + 1u][0]); }
    }

    // First thread writes workgroup result
    // (no barrier needed: thread (0,0) performed the final min itself)
    if (lx == 0u && ly == 0u) {
        let wg_idx = workgroup_id.y * num_workgroups.x + workgroup_id.x;
        partial_results[wg_idx] = tile[0][0];
    }
}
"#;
443
/// WGSL compute shader: classical Jacobi pivot search — find the
/// largest-magnitude off-diagonal element of an n×n matrix.
///
/// Each thread maps its linear index onto one upper-triangular pair (i, j),
/// i < j, and reads |a_ij|; a 256-wide shared-memory reduction then keeps
/// the per-workgroup maximum together with its (row, col) indices. Thread 0
/// writes a triple `[max_abs, row, col]` (indices encoded as f32) starting
/// at `result[workgroup_id.x * 3]`; the host reduces across workgroups.
///
/// NOTE(review): row/col indices round-trip through f32, which is exact
/// only for n < 2^24 — fine for practical Jacobi matrix sizes.
/// n == 0 and n == 1 are safe: total_pairs evaluates to 0, so no thread
/// reads the matrix (the u32 wrap in `n - 1u` for n == 0 is harmless).
pub(crate) const _JACOBI_MAX_OFFDIAG_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> matrix: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

struct MatrixParams {
    n: u32,
}

@group(0) @binding(2) var<uniform> params: MatrixParams;

// Workgroup shared memory for reduction
var<workgroup> partial_max: array<f32, 256>;
var<workgroup> partial_row: array<u32, 256>;
var<workgroup> partial_col: array<u32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let idx = global_id.x;
    let local_idx = local_id.x;
    let n = params.n;

    // Total off-diagonal elements: n*(n-1)/2
    let total_pairs = n * (n - 1u) / 2u;

    // Convert linear index to (i, j) where i < j
    var max_val: f32 = 0.0;
    var max_row: u32 = 0u;
    var max_col: u32 = 1u;

    if (idx < total_pairs) {
        // Map linear index to upper triangular (i, j) where i < j
        // by walking rows and subtracting each row's pair count
        var i: u32 = 0u;
        var j: u32 = 0u;
        var count: u32 = 0u;

        for (var row: u32 = 0u; row < n - 1u; row = row + 1u) {
            let pairs_in_row = n - 1u - row;
            if (count + pairs_in_row > idx) {
                i = row;
                j = row + 1u + (idx - count);
                break;
            }
            count = count + pairs_in_row;
        }

        let aij = matrix[i * n + j];
        max_val = abs(aij);
        max_row = i;
        max_col = j;
    }

    partial_max[local_idx] = max_val;
    partial_row[local_idx] = max_row;
    partial_col[local_idx] = max_col;

    workgroupBarrier();

    // Parallel reduction to find max within workgroup
    // (stride is workgroup-uniform, so the in-loop barrier is valid)
    var stride: u32 = 128u;
    while (stride > 0u) {
        if (local_idx < stride) {
            if (partial_max[local_idx + stride] > partial_max[local_idx]) {
                partial_max[local_idx] = partial_max[local_idx + stride];
                partial_row[local_idx] = partial_row[local_idx + stride];
                partial_col[local_idx] = partial_col[local_idx + stride];
            }
        }
        stride = stride / 2u;
        workgroupBarrier();
    }

    // First thread writes workgroup result
    if (local_idx == 0u) {
        let wg_idx = workgroup_id.x * 3u;
        result[wg_idx] = partial_max[0];
        result[wg_idx + 1u] = f32(partial_row[0]);
        result[wg_idx + 2u] = f32(partial_col[0]);
    }
}
"#;