trueno/backends/gpu/shaders/backward.rs

pub const SILU_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> grad_output: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_input: array<f32>;

struct Params {
    n: u32,
}

@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x + global_id.y * 65535u * 256u;
    if (idx >= params.n) {
        return;
    }

    let x = input[idx];
    let grad_out = grad_output[idx];

    // σ(x) = 1 / (1 + exp(-x))
    let sigmoid_x = 1.0 / (1.0 + exp(-x));

    // y = x * σ(x) (forward output)
    let y = x * sigmoid_x;

    // silu'(x) = σ(x) * (1 + x - y)
    let silu_prime = sigmoid_x * (1.0 + x - y);

    grad_input[idx] = grad_out * silu_prime;
}
"#;

pub const GEMM_BACKWARD_A_SHADER: &str = r#"
const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> grad_c: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_a: array<f32>;

struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

var<workgroup> tile_gc: array<f32, 256>;
var<workgroup> tile_bt: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let row = global_id.x; // M dimension
    let col = global_id.y; // K dimension
    let lr = local_id.x;
    let lc = local_id.y;

    var acc: f32 = 0.0;

    // Tile over N (reduction dimension for dA = dC @ B^T)
    let num_tiles = (dims.N + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load tile of grad_c[row, t*TILE + lc]
        let gc_col = t * TILE + lc;
        if (row < dims.M && gc_col < dims.N) {
            tile_gc[lr * TILE + lc] = grad_c[row * dims.N + gc_col];
        } else {
            tile_gc[lr * TILE + lc] = 0.0;
        }

        // Load tile of B^T[t*TILE + lr, col] = B[col, t*TILE + lr]
        // B is stored as B[K,N] row-major, so B[k,n] = b[k*N + n]
        // B^T[n,k] = B[k,n] = b[k*N + n]
        let bt_row = t * TILE + lr;
        if (col < dims.K && bt_row < dims.N) {
            tile_bt[lr * TILE + lc] = b[col * dims.N + bt_row];
        } else {
            tile_bt[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        // Accumulate: grad_a[row, col] += sum_n grad_c[row, n] * B^T[n, col]
        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            acc += tile_gc[lr * TILE + k] * tile_bt[k * TILE + lc];
        }

        workgroupBarrier();
    }

    if (row < dims.M && col < dims.K) {
        grad_a[row * dims.K + col] = acc;
    }
}
"#;

pub const GEMM_BACKWARD_B_SHADER: &str = r#"
const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> grad_c: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_b: array<f32>;

struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

var<workgroup> tile_at: array<f32, 256>;
var<workgroup> tile_gc: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let row = global_id.x; // K dimension
    let col = global_id.y; // N dimension
    let lr = local_id.x;
    let lc = local_id.y;

    var acc: f32 = 0.0;

    // Tile over M (reduction dimension for dB = A^T @ dC)
    let num_tiles = (dims.M + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load tile of A^T[row, t*TILE + lc] = A[t*TILE + lc, row]
        let at_col = t * TILE + lc;
        if (row < dims.K && at_col < dims.M) {
            tile_at[lr * TILE + lc] = a[at_col * dims.K + row];
        } else {
            tile_at[lr * TILE + lc] = 0.0;
        }

        // Load tile of grad_c[t*TILE + lr, col]
        let gc_row = t * TILE + lr;
        if (gc_row < dims.M && col < dims.N) {
            tile_gc[lr * TILE + lc] = grad_c[gc_row * dims.N + col];
        } else {
            tile_gc[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            acc += tile_at[lr * TILE + k] * tile_gc[k * TILE + lc];
        }

        workgroupBarrier();
    }

    if (row < dims.K && col < dims.N) {
        grad_b[row * dims.N + col] = acc;
    }
}
"#;

pub const RMSNORM_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> gamma: array<f32>;
@group(0) @binding(2) var<storage, read> grad_output: array<f32>;
@group(0) @binding(3) var<storage, read_write> grad_input: array<f32>;
@group(0) @binding(4) var<storage, read_write> grad_gamma: array<atomic<u32>>;

struct Params {
    num_rows: u32,
    hidden_dim: u32,
    eps_bits: u32, // f32 eps reinterpreted as u32 (WGSL uniform limitation)
    _pad: u32,
}

@group(0) @binding(5) var<uniform> params: Params;

var<workgroup> shared_sum_x2: array<f32, 256>;
var<workgroup> shared_sum_xgg: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) wg_id: vec3<u32>,
) {
    let row = wg_id.x;
    let tid = local_id.x;
    let h = params.hidden_dim;
    let eps = bitcast<f32>(params.eps_bits);

    if (row >= params.num_rows) {
        return;
    }

    let row_offset = row * h;

    // Pass 1: Compute sum(x²) and sum(x·dL/dy·γ) via stride loop
    var local_sum_x2: f32 = 0.0;
    var local_sum_xgg: f32 = 0.0;

    for (var i = tid; i < h; i = i + 256u) {
        let x_val = input[row_offset + i];
        let gy_val = grad_output[row_offset + i];
        let g_val = gamma[i];

        local_sum_x2 += x_val * x_val;
        local_sum_xgg += x_val * gy_val * g_val;
    }

    shared_sum_x2[tid] = local_sum_x2;
    shared_sum_xgg[tid] = local_sum_xgg;
    workgroupBarrier();

    // Workgroup reduction (256 → 1)
    for (var stride = 128u; stride > 0u; stride = stride >> 1u) {
        if (tid < stride) {
            shared_sum_x2[tid] += shared_sum_x2[tid + stride];
            shared_sum_xgg[tid] += shared_sum_xgg[tid + stride];
        }
        workgroupBarrier();
    }

    let sum_x2 = shared_sum_x2[0];
    let sum_xgg = shared_sum_xgg[0];

    // Compute rms and mean_xgg
    let h_f32 = f32(h);
    let mean_x2 = sum_x2 / h_f32;
    let variance_eps = mean_x2 + eps;
    let rms = sqrt(variance_eps);
    let inv_rms = 1.0 / rms;
    let mean_xgg = sum_xgg / h_f32;

    // Pass 2: Compute and store grad_x, accumulate grad_gamma
    for (var i = tid; i < h; i = i + 256u) {
        let x_val = input[row_offset + i];
        let gy_val = grad_output[row_offset + i];
        let g_val = gamma[i];

        // grad_x = (1/rms) * (γ * dL/dy - x/var_eps * mean_xgg)
        let gamma_gy = g_val * gy_val;
        let correction = (x_val / variance_eps) * mean_xgg;
        let grad_x = inv_rms * (gamma_gy - correction);
        grad_input[row_offset + i] = grad_x;

        // grad_gamma[i] += dL/dy * x / rms (accumulated across rows via atomic)
        let gg_contrib = gy_val * x_val * inv_rms;
        // Atomic float add via CAS loop (WGSL doesn't have native atomicAdd for f32)
        var old_bits = atomicLoad(&grad_gamma[i]);
        loop {
            let old_val = bitcast<f32>(old_bits);
            let new_val = old_val + gg_contrib;
            let new_bits = bitcast<u32>(new_val);
            let result = atomicCompareExchangeWeak(&grad_gamma[i], old_bits, new_bits);
            if (result.exchanged) {
                break;
            }
            old_bits = result.old_value;
        }
    }
}
"#;

pub const ROPE_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> grad_output: array<f32>;
@group(0) @binding(1) var<storage, read_write> grad_input: array<f32>;

struct Params {
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta_log2: f32, // log2(theta), e.g. log2(10000) ≈ 13.29
}

@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x + global_id.y * 65535u * 256u;
    let half_dim = params.head_dim / 2u;
    let total_pairs = params.num_heads * params.seq_len * half_dim;

    if (idx >= total_pairs) {
        return;
    }

    // Decompose idx into (head, pos, pair)
    let pair = idx % half_dim;
    let remaining = idx / half_dim;
    let pos = remaining % params.seq_len;
    let head = remaining / params.seq_len;

    // Compute rotation angle: θ_i = pos / θ^(2i/d)
    let freq_exp = -f32(2u * pair) / f32(params.head_dim) * params.theta_log2;
    let inv_freq = exp2(freq_exp);
    let angle = f32(pos) * inv_freq;
    let cos_angle = cos(angle);
    let sin_angle = sin(angle);

    // Element indices
    let base = head * params.seq_len * params.head_dim + pos * params.head_dim;
    let even_idx = base + 2u * pair;
    let odd_idx = base + 2u * pair + 1u;

    let dy_even = grad_output[even_idx];
    let dy_odd = grad_output[odd_idx];

    // Backward rotation (transpose of forward):
    // dx_even = dy_even * cos + dy_odd * sin
    // dx_odd = -dy_even * sin + dy_odd * cos
    grad_input[even_idx] = dy_even * cos_angle + dy_odd * sin_angle;
    grad_input[odd_idx] = -dy_even * sin_angle + dy_odd * cos_angle;
}
"#;

pub const ADAMW_STEP_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;

struct AdamWParams {
    n: u32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32, // 1 - β1^t
    bias_correction2: f32, // 1 - β2^t
}

@group(0) @binding(4) var<uniform> hp: AdamWParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x + global_id.y * 65535u * 256u;
    if (idx >= hp.n) {
        return;
    }

    let g = grads[idx];

    // Update moments
    m[idx] = hp.beta1 * m[idx] + (1.0 - hp.beta1) * g;
    v[idx] = hp.beta2 * v[idx] + (1.0 - hp.beta2) * g * g;

    // Bias correction
    let m_hat = m[idx] / hp.bias_correction1;
    let v_hat = v[idx] / hp.bias_correction2;

    // Weight decay + parameter update
    let p = params[idx];
    params[idx] = p - hp.lr * (m_hat / (sqrt(v_hat) + hp.eps) + hp.weight_decay * p);
}
"#;

pub const NF4_DEQUANT_SHADER: &str = r#"
// NF4 codebook (same as trueno::quantize::NF4_LUT)
const NF4_LUT: array<f32, 16> = array<f32, 16>(
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
);

@group(0) @binding(0) var<storage, read> packed: array<u32>;
@group(0) @binding(1) var<storage, read> scales: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct Params {
    n: u32,
    block_size: u32,
}

@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // 2D dispatch for large tensors (>16M elements): idx = x + y * 65535 * 256
    let idx = global_id.x + global_id.y * 65535u * 256u;
    if (idx >= params.n) {
        return;
    }

    // Each byte has 2 nibbles: low nibble = even index, high nibble = odd index
    let byte_idx = idx / 2u;
    let packed_val = packed[byte_idx / 4u];
    let byte_in_u32 = byte_idx % 4u;
    let byte_val = (packed_val >> (byte_in_u32 * 8u)) & 0xFFu;

    var nibble: u32;
    if (idx % 2u == 0u) {
        nibble = byte_val & 0xFu; // low nibble
    } else {
        nibble = (byte_val >> 4u) & 0xFu; // high nibble
    }

    let scale = scales[idx / params.block_size];
    output[idx] = NF4_LUT[nibble] * scale;
}
"#;

pub const CROSS_ENTROPY_FORWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> logits: array<f32>; // [seq_len, vocab_size]
@group(0) @binding(1) var<storage, read> labels: array<u32>; // [seq_len] — target token IDs
@group(0) @binding(2) var<storage, read_write> losses: array<f32>; // [seq_len] — per-token loss
@group(0) @binding(3) var<storage, read_write> logsumexp: array<f32>; // [seq_len] — saved for backward

struct CEParams {
    seq_len: u32,
    vocab_size: u32,
    loss_start: u32, // first response token position
    loss_end: u32, // last+1 response token position
}

@group(0) @binding(4) var<uniform> params: CEParams;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let pos = gid.x;
    if (pos >= params.seq_len) { return; }

    // Skip non-response positions
    if (pos < params.loss_start || pos >= params.loss_end) {
        losses[pos] = 0.0;
        logsumexp[pos] = 0.0;
        return;
    }

    let offset = pos * params.vocab_size;
    let label = labels[pos];

    // Pass 1: find max for numerical stability
    var max_val: f32 = -1e30;
    for (var v = 0u; v < params.vocab_size; v++) {
        max_val = max(max_val, logits[offset + v]);
    }

    // Pass 2: compute sum(exp(logit - max))
    var sum_exp: f32 = 0.0;
    for (var v = 0u; v < params.vocab_size; v++) {
        sum_exp += exp(logits[offset + v] - max_val);
    }

    let lse = max_val + log(sum_exp);
    logsumexp[pos] = lse;

    // Cross-entropy loss: -logits[label] + logsumexp
    if (label < params.vocab_size) {
        losses[pos] = -logits[offset + label] + lse;
    } else {
        losses[pos] = 0.0; // padding token
    }
}
"#;

pub const CROSS_ENTROPY_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> logits: array<f32>; // [seq_len, vocab_size] — overwritten with gradient
@group(0) @binding(1) var<storage, read> labels: array<u32>; // [seq_len]
@group(0) @binding(2) var<storage, read> logsumexp: array<f32>; // [seq_len] — from forward

struct CEBackParams {
    seq_len: u32,
    vocab_size: u32,
    loss_start: u32,
    loss_end: u32,
    scale: f32, // 1.0 / num_response_tokens
}

@group(0) @binding(3) var<uniform> params: CEBackParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // 2D dispatch for large tensors (seq × vocab > 65535 × 256)
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.vocab_size;
    if (idx >= total) { return; }

    let pos = idx / params.vocab_size;
    let v = idx % params.vocab_size;

    // Zero gradient for non-response positions
    if (pos < params.loss_start || pos >= params.loss_end) {
        logits[idx] = 0.0;
        return;
    }

    let lse = logsumexp[pos];
    let logit = logits[idx];

    // softmax(logit) = exp(logit - logsumexp)
    var grad = exp(logit - lse);

    // Subtract 1 at the label position
    let label = labels[pos];
    if (v == label) {
        grad -= 1.0;
    }

    // Scale by 1/num_response_tokens
    logits[idx] = grad * params.scale;
}
"#;