// trueno/backends/gpu/shaders/backward.rs
//! WGSL backward (gradient) shaders for training
//!
//! Contract: wgpu-training-v1.yaml (FALSIFY-WGPU-001)
//!
//! Each shader computes gradients for its corresponding forward operation.
//! All shaders match the CPU reference within ε < 1e-4 (fp32).
//!
//! ## Available Backward Shaders
//!
//! - [`SILU_BACKWARD_SHADER`]: SiLU activation gradient
//! - [`GEMM_BACKWARD_A_SHADER`]: dL/dA = dL/dC @ B^T
//! - [`GEMM_BACKWARD_B_SHADER`]: dL/dB = A^T @ dL/dC
//! - [`RMSNORM_BACKWARD_SHADER`]: RMSNorm gradient (dx, dγ)
//! - [`ROPE_BACKWARD_SHADER`]: RoPE gradient (negated sin rotation)
//! - [`SOFTMAX_BACKWARD_SHADER`]: Softmax Jacobian-vector product
//! - [`CROSS_ENTROPY_FORWARD_SHADER`]: Fused log-softmax + NLL loss
//! - [`CROSS_ENTROPY_BACKWARD_SHADER`]: Fused log-softmax + NLL gradient
//! - [`ADAMW_STEP_SHADER`]: AdamW optimizer step
//! - [`NF4_DEQUANT_SHADER`]: NF4 4-bit weight dequantization

// ============================================================================
// SiLU Backward: grad_x = grad_out * (σ(x) + x·σ(x)·(1 - σ(x)))
//              = grad_out * σ(x) * (1 + x - x·σ(x))
//
// Reference: Elfwing et al., "Sigmoid-Weighted Linear Units" (arXiv:1702.03118)
// ============================================================================

/// SiLU (Swish) backward shader.
///
/// Forward: y = x * σ(x)
/// Backward: dy/dx = σ(x) * (1 + x - y) where y = x * σ(x)
/// (equivalent to σ(x) + x·σ(x)·(1 - σ(x)); this form reuses y = x·σ(x)).
///
/// Dispatch: 1D over elements; the `global_id.y * 65535 * 256` term supports
/// a 2D grid for tensors exceeding the 65535-workgroup x-dimension limit
/// (same convention as the other element-wise shaders in this file).
///
/// Binding 0: input x (read)
/// Binding 1: grad_output dL/dy (read)
/// Binding 2: grad_input dL/dx (write)
/// Binding 3: uniform { n: u32 } — element count; threads with idx >= n exit
pub const SILU_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> grad_output: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_input: array<f32>;

struct Params {
    n: u32,
}

@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x + global_id.y * 65535u * 256u;
    if (idx >= params.n) {
        return;
    }

    let x = input[idx];
    let grad_out = grad_output[idx];

    // σ(x) = 1 / (1 + exp(-x))
    let sigmoid_x = 1.0 / (1.0 + exp(-x));

    // y = x * σ(x) (forward output)
    let y = x * sigmoid_x;

    // silu'(x) = σ(x) * (1 + x - y)
    let silu_prime = sigmoid_x * (1.0 + x - y);

    grad_input[idx] = grad_out * silu_prime;
}
"#;
69
// ============================================================================
// GEMM Backward A: grad_A[M,K] = grad_C[M,N] @ B^T[N,K]
//
// Reuses tiled matmul pattern. The "B transposed" is handled by swapping
// the indexing: B[j,i] instead of B[i,j].
// ============================================================================

/// GEMM backward for A: dL/dA = dL/dC @ B^T
///
/// Forward: C[M,N] = A[M,K] @ B[K,N]
/// Backward: dL/dA[M,K] = dL/dC[M,N] @ B^T[N,K]
///
/// 16×16 workgroup computes a 16×16 tile of grad_A, reducing over N in
/// TILE-sized steps through workgroup (shared) memory. Out-of-range tile
/// entries are zero-padded, and no thread returns early, so both
/// `workgroupBarrier()` calls are reached uniformly by the whole workgroup.
///
/// Dispatch: x covers M (rows of grad_A), y covers K (columns).
///
/// Binding 0: grad_c (dL/dC) [M*N] (read)
/// Binding 1: b [K*N] (read) — accessed transposed, row-major B[k,n] = b[k*N + n]
/// Binding 2: grad_a (dL/dA) [M*K] (write)
/// Binding 3: uniform { M, K, N }
pub const GEMM_BACKWARD_A_SHADER: &str = r#"
const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> grad_c: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_a: array<f32>;

struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

var<workgroup> tile_gc: array<f32, 256>;
var<workgroup> tile_bt: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let row = global_id.x;  // M dimension
    let col = global_id.y;  // K dimension
    let lr = local_id.x;
    let lc = local_id.y;

    var acc: f32 = 0.0;

    // Tile over N (reduction dimension for dA = dC @ B^T)
    let num_tiles = (dims.N + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load tile of grad_c[row, t*TILE + lc]
        let gc_col = t * TILE + lc;
        if (row < dims.M && gc_col < dims.N) {
            tile_gc[lr * TILE + lc] = grad_c[row * dims.N + gc_col];
        } else {
            tile_gc[lr * TILE + lc] = 0.0;
        }

        // Load tile of B^T[t*TILE + lr, col] = B[col, t*TILE + lr]
        // B is stored as B[K,N] row-major, so B[k,n] = b[k*N + n]
        // B^T[n,k] = B[k,n] = b[k*N + n]
        let bt_row = t * TILE + lr;
        if (col < dims.K && bt_row < dims.N) {
            tile_bt[lr * TILE + lc] = b[col * dims.N + bt_row];
        } else {
            tile_bt[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        // Accumulate: grad_a[row, col] += sum_n grad_c[row, n] * B^T[n, col]
        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            acc += tile_gc[lr * TILE + k] * tile_bt[k * TILE + lc];
        }

        workgroupBarrier();
    }

    if (row < dims.M && col < dims.K) {
        grad_a[row * dims.K + col] = acc;
    }
}
"#;
153
/// GEMM backward for B: dL/dB = A^T @ dL/dC
///
/// Forward: C[M,N] = A[M,K] @ B[K,N]
/// Backward: dL/dB[K,N] = A^T[K,M] @ dL/dC[M,N]
///
/// Mirror of [`GEMM_BACKWARD_A_SHADER`]: 16×16 tiled reduction, here over M.
/// The transpose of A is realized purely by index swap —
/// A^T[k,m] = A[m,k] = a[m*K + k] — no transposed copy is materialized.
/// Out-of-range tile entries are zero-padded; no early returns, so the
/// barriers are uniformly executed.
///
/// Dispatch: x covers K (rows of grad_B), y covers N (columns).
///
/// Binding 0: a [M*K] (read) — accessed transposed
/// Binding 1: grad_c (dL/dC) [M*N] (read)
/// Binding 2: grad_b (dL/dB) [K*N] (write)
/// Binding 3: uniform { M, K, N }
pub const GEMM_BACKWARD_B_SHADER: &str = r#"
const TILE: u32 = 16u;

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> grad_c: array<f32>;
@group(0) @binding(2) var<storage, read_write> grad_b: array<f32>;

struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
}

@group(0) @binding(3) var<uniform> dims: Dimensions;

var<workgroup> tile_at: array<f32, 256>;
var<workgroup> tile_gc: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
) {
    let row = global_id.x;  // K dimension
    let col = global_id.y;  // N dimension
    let lr = local_id.x;
    let lc = local_id.y;

    var acc: f32 = 0.0;

    // Tile over M (reduction dimension for dB = A^T @ dC)
    let num_tiles = (dims.M + TILE - 1u) / TILE;

    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load tile of A^T[row, t*TILE + lc] = A[t*TILE + lc, row]
        let at_col = t * TILE + lc;
        if (row < dims.K && at_col < dims.M) {
            tile_at[lr * TILE + lc] = a[at_col * dims.K + row];
        } else {
            tile_at[lr * TILE + lc] = 0.0;
        }

        // Load tile of grad_c[t*TILE + lr, col]
        let gc_row = t * TILE + lr;
        if (gc_row < dims.M && col < dims.N) {
            tile_gc[lr * TILE + lc] = grad_c[gc_row * dims.N + col];
        } else {
            tile_gc[lr * TILE + lc] = 0.0;
        }

        workgroupBarrier();

        for (var k: u32 = 0u; k < TILE; k = k + 1u) {
            acc += tile_at[lr * TILE + k] * tile_gc[k * TILE + lc];
        }

        workgroupBarrier();
    }

    if (row < dims.K && col < dims.N) {
        grad_b[row * dims.N + col] = acc;
    }
}
"#;
227
// ============================================================================
// RMSNorm Backward
//
// Forward: y_i = x_i / rms(x) * γ_i, where rms(x) = sqrt(mean(x²) + ε)
// Backward:
//   dL/dx_i = (1/rms) * (γ_i * dL/dy_i - x_i/rms² * mean(x · dL/dy · γ))
//   dL/dγ_i = Σ_batch (dL/dy_i * x_i / rms)
//
// Uses workgroup reduction for mean computation (one workgroup per row).
// ============================================================================

/// RMSNorm backward shader.
///
/// One 256-thread workgroup per row: pass 1 computes Σx² and Σ(x·dL/dy·γ)
/// via stride loops + a shared-memory tree reduction; pass 2 writes dL/dx and
/// accumulates dL/dγ across rows with a bitcast CAS loop (WGSL has no native
/// f32 atomicAdd). The `row >= num_rows` early-out is workgroup-uniform
/// (based on `workgroup_id`), so the barriers remain uniformly executed.
///
/// Caller notes:
/// - grad_gamma must be zero-initialized before dispatch — the shader only
///   adds contributions into it.
/// - eps is passed bit-reinterpreted as u32 in the uniform (eps_bits).
///
/// Binding 0: input x [num_rows * hidden_dim] (read)
/// Binding 1: gamma [hidden_dim] (read)
/// Binding 2: grad_output dL/dy [num_rows * hidden_dim] (read)
/// Binding 3: grad_input dL/dx [num_rows * hidden_dim] (write)
/// Binding 4: grad_gamma dL/dγ [hidden_dim] (read_write, CAS-accumulated)
/// Binding 5: uniform { num_rows, hidden_dim, eps_bits, _pad }
pub const RMSNORM_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> gamma: array<f32>;
@group(0) @binding(2) var<storage, read> grad_output: array<f32>;
@group(0) @binding(3) var<storage, read_write> grad_input: array<f32>;
@group(0) @binding(4) var<storage, read_write> grad_gamma: array<atomic<u32>>;

struct Params {
    num_rows: u32,
    hidden_dim: u32,
    eps_bits: u32,  // f32 eps reinterpreted as u32 (WGSL uniform limitation)
    _pad: u32,
}

@group(0) @binding(5) var<uniform> params: Params;

var<workgroup> shared_sum_x2: array<f32, 256>;
var<workgroup> shared_sum_xgg: array<f32, 256>;

@compute @workgroup_size(256)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) wg_id: vec3<u32>,
) {
    let row = wg_id.x;
    let tid = local_id.x;
    let h = params.hidden_dim;
    let eps = bitcast<f32>(params.eps_bits);

    // Uniform per-workgroup exit: safe w.r.t. the barriers below.
    if (row >= params.num_rows) {
        return;
    }

    let row_offset = row * h;

    // Pass 1: Compute sum(x²) and sum(x·dL/dy·γ) via stride loop
    var local_sum_x2: f32 = 0.0;
    var local_sum_xgg: f32 = 0.0;

    for (var i = tid; i < h; i = i + 256u) {
        let x_val = input[row_offset + i];
        let gy_val = grad_output[row_offset + i];
        let g_val = gamma[i];

        local_sum_x2 += x_val * x_val;
        local_sum_xgg += x_val * gy_val * g_val;
    }

    shared_sum_x2[tid] = local_sum_x2;
    shared_sum_xgg[tid] = local_sum_xgg;
    workgroupBarrier();

    // Workgroup reduction (256 → 1)
    for (var stride = 128u; stride > 0u; stride = stride >> 1u) {
        if (tid < stride) {
            shared_sum_x2[tid] += shared_sum_x2[tid + stride];
            shared_sum_xgg[tid] += shared_sum_xgg[tid + stride];
        }
        workgroupBarrier();
    }

    let sum_x2 = shared_sum_x2[0];
    let sum_xgg = shared_sum_xgg[0];

    // Compute rms and mean_xgg
    let h_f32 = f32(h);
    let mean_x2 = sum_x2 / h_f32;
    let variance_eps = mean_x2 + eps;
    let rms = sqrt(variance_eps);
    let inv_rms = 1.0 / rms;
    let mean_xgg = sum_xgg / h_f32;

    // Pass 2: Compute and store grad_x, accumulate grad_gamma
    for (var i = tid; i < h; i = i + 256u) {
        let x_val = input[row_offset + i];
        let gy_val = grad_output[row_offset + i];
        let g_val = gamma[i];

        // grad_x = (1/rms) * (γ * dL/dy - x/var_eps * mean_xgg)
        let gamma_gy = g_val * gy_val;
        let correction = (x_val / variance_eps) * mean_xgg;
        let grad_x = inv_rms * (gamma_gy - correction);
        grad_input[row_offset + i] = grad_x;

        // grad_gamma[i] += dL/dy * x / rms (accumulated across rows via atomic)
        let gg_contrib = gy_val * x_val * inv_rms;
        // Atomic float add via CAS loop (WGSL doesn't have native atomicAdd for f32).
        // A spurious weak-CAS failure simply retries with the refreshed value.
        var old_bits = atomicLoad(&grad_gamma[i]);
        loop {
            let old_val = bitcast<f32>(old_bits);
            let new_val = old_val + gg_contrib;
            let new_bits = bitcast<u32>(new_val);
            let result = atomicCompareExchangeWeak(&grad_gamma[i], old_bits, new_bits);
            if (result.exchanged) {
                break;
            }
            old_bits = result.old_value;
        }
    }
}
"#;
350
// ============================================================================
// RoPE Backward: same rotation but with negated sin
//
// Forward: (x_even, x_odd) → (x_even*cos - x_odd*sin, x_even*sin + x_odd*cos)
// Backward: (dx_even, dx_odd) → (dx_even*cos + dx_odd*sin, -dx_even*sin + dx_odd*cos)
//
// The backward is the TRANSPOSE of the forward rotation matrix.
// ============================================================================

/// RoPE backward shader.
///
/// One thread per (head, pos, pair): rotates each (even, odd) gradient pair
/// by the transpose (inverse) of the forward rotation. The angle is
/// pos / θ^(2i/d), computed as exp2(-2i/d · log2 θ) to avoid pow().
///
/// NOTE(review): `total_pairs = num_heads * seq_len * head_dim/2` does not
/// include the batch dimension mentioned in the binding comment below —
/// presumably the caller folds batch into `num_heads` or dispatches once per
/// batch element. Confirm against the dispatch code.
///
/// Binding 0: grad_output [batch * num_heads * seq_len * head_dim] (read)
/// Binding 1: grad_input [same shape] (write)
/// Binding 2: uniform { num_heads, head_dim, seq_len, theta_log2 }
pub const ROPE_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> grad_output: array<f32>;
@group(0) @binding(1) var<storage, read_write> grad_input: array<f32>;

struct Params {
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta_log2: f32,  // log2(theta), e.g. log2(10000) ≈ 13.29
}

@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x + global_id.y * 65535u * 256u;
    let half_dim = params.head_dim / 2u;
    let total_pairs = params.num_heads * params.seq_len * half_dim;

    if (idx >= total_pairs) {
        return;
    }

    // Decompose idx into (head, pos, pair)
    let pair = idx % half_dim;
    let remaining = idx / half_dim;
    let pos = remaining % params.seq_len;
    let head = remaining / params.seq_len;

    // Compute rotation angle: θ_i = pos / θ^(2i/d)
    let freq_exp = -f32(2u * pair) / f32(params.head_dim) * params.theta_log2;
    let inv_freq = exp2(freq_exp);
    let angle = f32(pos) * inv_freq;
    let cos_angle = cos(angle);
    let sin_angle = sin(angle);

    // Element indices
    let base = head * params.seq_len * params.head_dim + pos * params.head_dim;
    let even_idx = base + 2u * pair;
    let odd_idx = base + 2u * pair + 1u;

    let dy_even = grad_output[even_idx];
    let dy_odd = grad_output[odd_idx];

    // Backward rotation (transpose of forward):
    // dx_even = dy_even * cos + dy_odd * sin
    // dx_odd  = -dy_even * sin + dy_odd * cos
    grad_input[even_idx] = dy_even * cos_angle + dy_odd * sin_angle;
    grad_input[odd_idx] = -dy_even * sin_angle + dy_odd * cos_angle;
}
"#;
416
// ============================================================================
// AdamW Optimizer Step
//
// For each parameter:
//   m = β1 * m + (1 - β1) * grad
//   v = β2 * v + (1 - β2) * grad²
//   m_hat = m / (1 - β1^t)
//   v_hat = v / (1 - β2^t)
//   param = param - lr * (m_hat / (sqrt(v_hat) + ε) + weight_decay * param)
// ============================================================================

/// AdamW optimizer step shader (decoupled weight decay).
///
/// One thread per parameter; moments are updated in place, then the
/// bias-corrected step is applied. Weight decay multiplies the parameter
/// directly (inside the lr factor), not the gradient — the AdamW formulation.
///
/// Note: the bias corrections 1-β1^t and 1-β2^t are precomputed on the host
/// each step and passed in the uniform; the shader does no pow().
///
/// Binding 0: params (read_write)
/// Binding 1: grads (read)
/// Binding 2: m (first moment, read_write)
/// Binding 3: v (second moment, read_write)
/// Binding 4: uniform { n, lr, beta1, beta2, eps, weight_decay, bc1, bc2 }
pub const ADAMW_STEP_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;

struct AdamWParams {
    n: u32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    bias_correction1: f32,  // 1 - β1^t
    bias_correction2: f32,  // 1 - β2^t
}

@group(0) @binding(4) var<uniform> hp: AdamWParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x + global_id.y * 65535u * 256u;
    if (idx >= hp.n) {
        return;
    }

    let g = grads[idx];

    // Update moments
    m[idx] = hp.beta1 * m[idx] + (1.0 - hp.beta1) * g;
    v[idx] = hp.beta2 * v[idx] + (1.0 - hp.beta2) * g * g;

    // Bias correction
    let m_hat = m[idx] / hp.bias_correction1;
    let v_hat = v[idx] / hp.bias_correction2;

    // Weight decay + parameter update
    let p = params[idx];
    params[idx] = p - hp.lr * (m_hat / (sqrt(v_hat) + hp.eps) + hp.weight_decay * p);
}
"#;
476
// ============================================================================
// NF4 Dequantization: 4-bit NormalFloat to fp32
//
// Each byte stores two 4-bit values. block_size=64 elements share one scale.
// Lookup table maps 4-bit index → fp32 value, then multiply by scale.
// ============================================================================

/// NF4 weight dequantization shader.
///
/// One thread per output element. The packed buffer is read as u32 words
/// (8 nibbles each); the byte is selected little-endian within the word,
/// then the low nibble maps to the even element and the high nibble to the
/// odd element. output[i] = NF4_LUT[nibble] * scales[i / block_size].
///
/// Binding 0: packed_data [n/2 bytes] as u32 (read) — each u32 has 8 nibbles
/// Binding 1: scales [n/block_size] (read)
/// Binding 2: output [n] (write)
/// Binding 3: uniform { n, block_size }
pub const NF4_DEQUANT_SHADER: &str = r#"
// NF4 codebook (same as trueno::quantize::NF4_LUT)
const NF4_LUT: array<f32, 16> = array<f32, 16>(
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
);

@group(0) @binding(0) var<storage, read> packed: array<u32>;
@group(0) @binding(1) var<storage, read> scales: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

struct Params {
    n: u32,
    block_size: u32,
}

@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // 2D dispatch for large tensors (>16M elements): idx = x + y * 65535 * 256
    let idx = global_id.x + global_id.y * 65535u * 256u;
    if (idx >= params.n) {
        return;
    }

    // Each byte has 2 nibbles: low nibble = even index, high nibble = odd index
    let byte_idx = idx / 2u;
    let packed_val = packed[byte_idx / 4u];
    let byte_in_u32 = byte_idx % 4u;
    let byte_val = (packed_val >> (byte_in_u32 * 8u)) & 0xFFu;

    var nibble: u32;
    if (idx % 2u == 0u) {
        nibble = byte_val & 0xFu;  // low nibble
    } else {
        nibble = (byte_val >> 4u) & 0xFu;  // high nibble
    }

    let scale = scales[idx / params.block_size];
    output[idx] = NF4_LUT[nibble] * scale;
}
"#;
535
// ============================================================================
// Fused Cross-Entropy Forward: loss = -log(softmax(logits)[label])
//
// One workgroup per token position. Each computes:
// 1. max(logits) for numerical stability
// 2. logsumexp = max + log(Σ exp(logit - max))
// 3. loss = -logits[label] + logsumexp
//
// Saves logsumexp per position for backward pass.
// Response-only masking: positions outside [loss_start, loss_end) contribute 0.
//
// Contract: fused-cross-entropy-v1 / fused_forward
// ============================================================================

/// Fused cross-entropy forward loss shader.
///
/// Each workgroup (size 1 — a single thread loops serially over the vocab)
/// computes the loss for one token position via the stable two-pass
/// logsumexp. Outputs: losses[pos] and logsumexp[pos] (saved for backward).
///
/// Masking: positions outside [loss_start, loss_end) write 0 to both
/// outputs; labels >= vocab_size (padding tokens) write 0 loss but still
/// record logsumexp.
///
/// NOTE(review): workgroup_size(1) leaves the GPU mostly idle per position —
/// acceptable only if seq_len is large enough to fill the device; confirm
/// against the contract's performance budget.
pub const CROSS_ENTROPY_FORWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> logits: array<f32>;   // [seq_len, vocab_size]
@group(0) @binding(1) var<storage, read> labels: array<u32>;   // [seq_len] — target token IDs
@group(0) @binding(2) var<storage, read_write> losses: array<f32>;     // [seq_len] — per-token loss
@group(0) @binding(3) var<storage, read_write> logsumexp: array<f32>;  // [seq_len] — saved for backward

struct CEParams {
    seq_len: u32,
    vocab_size: u32,
    loss_start: u32,   // first response token position
    loss_end: u32,     // last+1 response token position
}

@group(0) @binding(4) var<uniform> params: CEParams;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let pos = gid.x;
    if (pos >= params.seq_len) { return; }

    // Skip non-response positions
    if (pos < params.loss_start || pos >= params.loss_end) {
        losses[pos] = 0.0;
        logsumexp[pos] = 0.0;
        return;
    }

    let offset = pos * params.vocab_size;
    let label = labels[pos];

    // Pass 1: find max for numerical stability
    var max_val: f32 = -1e30;
    for (var v = 0u; v < params.vocab_size; v++) {
        max_val = max(max_val, logits[offset + v]);
    }

    // Pass 2: compute sum(exp(logit - max))
    var sum_exp: f32 = 0.0;
    for (var v = 0u; v < params.vocab_size; v++) {
        sum_exp += exp(logits[offset + v] - max_val);
    }

    let lse = max_val + log(sum_exp);
    logsumexp[pos] = lse;

    // Cross-entropy loss: -logits[label] + logsumexp
    if (label < params.vocab_size) {
        losses[pos] = -logits[offset + label] + lse;
    } else {
        losses[pos] = 0.0;  // padding token
    }
}
"#;
607
// ============================================================================
// Fused Cross-Entropy Backward: grad_logits = softmax(logits) - one_hot(label)
//
// Writes gradient IN-PLACE into the logits buffer (no allocation).
// Uses saved logsumexp from forward pass.
//
// Contract: fused-cross-entropy-v1 / fused_backward
// ============================================================================

/// Fused cross-entropy backward shader — writes gradient in-place into logits.
///
/// grad = (exp(logit - logsumexp) - one_hot(label)) * scale, using the
/// logsumexp saved by [`CROSS_ENTROPY_FORWARD_SHADER`]. One thread per
/// (position, vocab) element; 2D dispatch for seq_len*vocab_size > 65535*256.
///
/// Zero gradient is written for:
/// - positions outside [loss_start, loss_end) (non-response), and
/// - positions whose label >= vocab_size (padding tokens) — these contribute
///   zero loss in the forward pass, so emitting the softmax gradient for
///   them (as the previous version did) was inconsistent with the forward.
pub const CROSS_ENTROPY_BACKWARD_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> logits: array<f32>; // [seq_len, vocab_size] — overwritten with gradient
@group(0) @binding(1) var<storage, read> labels: array<u32>;       // [seq_len]
@group(0) @binding(2) var<storage, read> logsumexp: array<f32>;    // [seq_len] — from forward

struct CEBackParams {
    seq_len: u32,
    vocab_size: u32,
    loss_start: u32,
    loss_end: u32,
    scale: f32,    // 1.0 / num_response_tokens
}

@group(0) @binding(3) var<uniform> params: CEBackParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // 2D dispatch for large tensors (seq × vocab > 65535 × 256)
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = params.seq_len * params.vocab_size;
    if (idx >= total) { return; }

    let pos = idx / params.vocab_size;
    let v = idx % params.vocab_size;

    // Zero gradient for non-response positions
    if (pos < params.loss_start || pos >= params.loss_end) {
        logits[idx] = 0.0;
        return;
    }

    // Padding labels (label >= vocab_size) contributed zero loss in the
    // forward pass, so their gradient must be zero as well.
    let label = labels[pos];
    if (label >= params.vocab_size) {
        logits[idx] = 0.0;
        return;
    }

    let lse = logsumexp[pos];
    let logit = logits[idx];

    // softmax(logit) = exp(logit - logsumexp)
    var grad = exp(logit - lse);

    // Subtract 1 at the label position
    if (v == label) {
        grad -= 1.0;
    }

    // Scale by 1/num_response_tokens
    logits[idx] = grad * params.scale;
}
"#;
664"#;