Skip to main content

entrenar/train/transformer_trainer/
wgpu_backward.rs

//! WGPU backward pass through 36 transformer layers + LoRA AdamW
//!
//! Backpropagates gradients from lm_head through all FFN layers,
//! computes LoRA Q/V gradients, and runs AdamW on the adapter weights.
//!
//! # Contract: C-WGPU-TRAIN-007 (layer backward + LoRA optimizer)
7
8#[cfg(feature = "gpu")]
9use trueno::backends::gpu::GpuDevice;
10
11/// CPU AdamW step — avoids GPU dispatch overhead for small LoRA tensors
12/// LoRA A: rank×hidden = 16×2560 = 40K params → ~0.1ms on CPU vs ~50ms GPU dispatch
13#[cfg(feature = "gpu")]
/// Applies one bias-corrected AdamW update to `params` in place.
///
/// For every element: the first/second moment estimates `m` and `v` are
/// updated from `grad`, bias-corrected with `1 / (1 - beta^step)`, and the
/// parameter is moved by `lr * (m_hat / (sqrt(v_hat) + eps) + wd * param)`
/// (decoupled weight decay).
///
/// * `params`, `m`, `v` — updated in place; only the first `params.len()`
///   elements of `grad`, `m` and `v` are used.
/// * `step` — 1-based optimizer step. A value of `0` is treated as `1`,
///   because `1 - beta^0 == 0` would otherwise turn every parameter into NaN.
///
/// # Panics
/// Panics if `grad`, `m` or `v` is shorter than `params`.
fn cpu_adamw(
    params: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    wd: f32,
    step: u32,
) {
    // Slice once up front: a single bounds check instead of one per element.
    let n = params.len();
    let grad = &grad[..n];
    let m = &mut m[..n];
    let v = &mut v[..n];

    // Clamp to >= 1 (avoids division by zero) and saturate instead of
    // wrapping to a negative exponent for steps above i32::MAX.
    let t = i32::try_from(step.max(1)).unwrap_or(i32::MAX);
    let bc1 = 1.0 / (1.0 - beta1.powi(t));
    let bc2 = 1.0 / (1.0 - beta2.powi(t));

    let moments = m.iter_mut().zip(v.iter_mut());
    for ((p, &g), (m_i, v_i)) in params.iter_mut().zip(grad).zip(moments) {
        *m_i = beta1 * *m_i + (1.0 - beta1) * g;
        *v_i = beta2 * *v_i + (1.0 - beta2) * g * g;
        let m_hat = *m_i * bc1;
        let v_hat = *v_i * bc2;
        *p -= lr * (m_hat / (v_hat.sqrt() + eps) + wd * *p);
    }
}
36
37/// Per-layer forward activations cached for backward pass
38#[cfg(feature = "gpu")]
/// Forward-pass buffers of a single transformer layer, kept so the backward
/// pass can reuse them instead of recomputing the forward pass.
///
/// Every field is a flat `f32` buffer; the bracketed shape in each field
/// comment is the logical layout, stored with the last dimension contiguous.
/// `Default` yields all-empty buffers, which is convenient for tests and for
/// building the struct incrementally.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LayerActivations {
    /// Input to attention (after RMSNorm) [seq_len, hidden_size]
    pub attn_input: Vec<f32>,
    /// Input to FFN (after attention + RMSNorm) [seq_len, hidden_size]
    pub hidden_input: Vec<f32>,
    /// Gate projection output [seq_len, intermediate_size]
    pub gate_output: Vec<f32>,
    /// Up projection output [seq_len, intermediate_size]
    pub up_output: Vec<f32>,
    /// SiLU(gate) output [seq_len, intermediate_size]
    pub silu_gate: Vec<f32>,
    /// Attention: Q after QK-norm + RoPE [seq_len, q_dim]
    pub q: Vec<f32>,
    /// Attention: K after QK-norm + RoPE [seq_len, kv_dim]
    pub k: Vec<f32>,
    /// Attention: V [seq_len, kv_dim]
    pub v: Vec<f32>,
    /// Attention: softmax weights per head [num_heads, seq_len, seq_len]
    pub attn_weights: Vec<f32>,
    /// Attention: context before O projection [seq_len, q_dim]
    pub context: Vec<f32>,
    /// Cached LoRA Q intermediate: hidden @ A_q^T [seq_len, rank]
    pub lora_q_h: Vec<f32>,
    /// Cached LoRA V intermediate: hidden @ A_v^T [seq_len, rank]
    pub lora_v_h: Vec<f32>,
}
65
/// Returns an error when a buffer belonging to `layer` is shorter than the
/// backward pass needs.
#[cfg(feature = "gpu")]
fn ensure_len(what: &str, layer: usize, actual: usize, needed: usize) -> Result<(), String> {
    if actual < needed {
        return Err(format!(
            "layer {layer}: {what} has {actual} elements, need at least {needed}"
        ));
    }
    Ok(())
}

/// Sum of squares of a gradient buffer.
#[cfg(feature = "gpu")]
fn sq_norm(g: &[f32]) -> f32 {
    g.iter().map(|x| x * x).sum()
}

/// SwiGLU backward: splits `grad_swiglu` into the gradients of the gate and up
/// projection outputs. Returns `(grad_gate, grad_up)`, each as long as `grad_swiglu`.
#[cfg(feature = "gpu")]
fn swiglu_backward(
    grad_swiglu: &[f32],
    gate_output: &[f32],
    up_output: &[f32],
    silu_gate: &[f32],
) -> (Vec<f32>, Vec<f32>) {
    let n = grad_swiglu.len();
    let mut grad_gate = Vec::with_capacity(n);
    let mut grad_up = Vec::with_capacity(n);
    let inputs = gate_output[..n].iter().zip(&up_output[..n]).zip(&silu_gate[..n]);
    for (&g, ((&x, &up), &silu)) in grad_swiglu.iter().zip(inputs) {
        let sig = 1.0 / (1.0 + (-x).exp());
        let y = x * sig;
        // d/dx [x * sigmoid(x)] = sig * (1 + x * (1 - sig)) = sig * (1 + x - y)
        let silu_prime = sig * (1.0 + x - y);
        grad_gate.push(g * up * silu_prime);
        grad_up.push(g * silu);
    }
    (grad_gate, grad_up)
}

/// Attention backward through softmax·V: turns `grad_context` [s, nh*hd] into
/// `(grad_q [s, nh*hd], grad_v [s, nkv*hd])`.
///
/// The caller guarantees `nkv > 0` and `nh % nkv == 0`. Each query head reads
/// the K/V head `head / (nh / nkv)` (grouped-query layout).
#[cfg(feature = "gpu")]
fn attention_backward_qv(
    grad_context: &[f32],
    k: &[f32],
    v: &[f32],
    attn_weights: &[f32],
    s: usize,
    nh: usize,
    nkv: usize,
    hd: usize,
) -> (Vec<f32>, Vec<f32>) {
    let q_dim = nh * hd;
    let kv_dim = nkv * hd;
    let heads_per_kv = nh / nkv;
    let scale = 1.0 / (hd as f32).sqrt();

    let mut grad_q = vec![0.0f32; s * q_dim];
    let mut grad_v = vec![0.0f32; s * kv_dim];
    // Reused for every (head, query) pair instead of reallocating.
    let mut grad_scores = vec![0.0f32; s];

    for head in 0..nh {
        let q_col = head * hd;
        let kv_col = (head / heads_per_kv) * hd;
        for qi in 0..s {
            let weights = &attn_weights[(head * s + qi) * s..][..s];
            let g_ctx = &grad_context[qi * q_dim + q_col..][..hd];

            // grad_scores[ki] = <grad_context[qi, head, :], v[ki, kv_head, :]>
            let mut dot_sum = 0.0f32;
            for (ki, (gs, &w)) in grad_scores.iter_mut().zip(weights).enumerate() {
                let v_row = &v[ki * kv_dim + kv_col..][..hd];
                let mut acc = 0.0f32;
                for (&g, &x) in g_ctx.iter().zip(v_row) {
                    acc += g * x;
                }
                *gs = acc;
                dot_sum += w * acc;
            }

            let gq_row = &mut grad_q[qi * q_dim + q_col..][..hd];
            for (ki, (&gs, &w)) in grad_scores.iter().zip(weights).enumerate() {
                // Softmax backward: grad_pre = attn_w * (grad_scores - dot_sum), times the 1/sqrt(hd) score scale.
                let g_pre = w * (gs - dot_sum) * scale;
                let k_row = &k[ki * kv_dim + kv_col..][..hd];
                for (gq, &kx) in gq_row.iter_mut().zip(k_row) {
                    *gq += g_pre * kx;
                }
                // grad_v[ki] += attn_w[qi, ki] * grad_context[qi]; zero weights contribute nothing.
                if w > 0.0 {
                    let gv_row = &mut grad_v[ki * kv_dim + kv_col..][..hd];
                    for (gv, &g) in gv_row.iter_mut().zip(g_ctx) {
                        *gv += w * g;
                    }
                }
            }
        }
    }
    (grad_q, grad_v)
}

/// LoRA adapter gradients for one projection. Returns `(grad_a [rank, h], grad_b [out_dim, rank])`.
///
/// * `grad_b = scaling * grad_out^T @ h_cached`
/// * `grad_a = (scaling * grad_out @ B)^T @ x`
///
/// Rows are walked contiguously; every output element is still accumulated in
/// ascending sequence order, so the sums match a naive triple loop.
#[cfg(feature = "gpu")]
fn lora_grads(
    grad_out: &[f32],
    h_cached: &[f32],
    x: &[f32],
    b: &[f32],
    s: usize,
    out_dim: usize,
    rank: usize,
    h: usize,
    scaling: f32,
) -> (Vec<f32>, Vec<f32>) {
    let mut grad_a = vec![0.0f32; rank * h];
    let mut grad_b = vec![0.0f32; out_dim * rank];
    let mut grad_h = vec![0.0f32; rank];

    for si in 0..s {
        let g_row = &grad_out[si * out_dim..][..out_dim];
        let hc_row = &h_cached[si * rank..][..rank];
        let x_row = &x[si * h..][..h];

        grad_h.iter_mut().for_each(|g| *g = 0.0);
        for (oi, &g) in g_row.iter().enumerate() {
            let gb_row = &mut grad_b[oi * rank..][..rank];
            for (gb, &hc) in gb_row.iter_mut().zip(hc_row) {
                *gb += g * hc;
            }
            let b_row = &b[oi * rank..][..rank];
            for (gh, &bx) in grad_h.iter_mut().zip(b_row) {
                *gh += g * bx;
            }
        }
        for (ri, gh) in grad_h.iter().enumerate() {
            let gh = gh * scaling;
            let ga_row = &mut grad_a[ri * h..][..h];
            for (ga, &xv) in ga_row.iter_mut().zip(x_row) {
                *ga += gh * xv;
            }
        }
    }
    for gb in &mut grad_b {
        *gb *= scaling;
    }
    (grad_a, grad_b)
}

/// Run the backward pass through all layers and update the LoRA adapters.
///
/// Starting from `grad_hidden` (the gradient coming out of the lm_head
/// backward), walks the layers in reverse order. For each layer it
/// backpropagates through the SwiGLU FFN (GPU GEMMs via `device`), derives the
/// Q and V gradients from the cached attention activations, computes the
/// LoRA A/B gradients for the Q and V adapters and applies a CPU AdamW step to
/// them. B is trained with 16x the learning rate of A (LoRA+).
///
/// `grad_hidden` is overwritten in place with the gradient propagated to the
/// input of layer 0.
///
/// # Returns
/// The L2 norm over all LoRA gradients (A and B, Q and V, every layer).
///
/// # Errors
/// Returns `Err` when a GEMM fails, when a layer's FFN/attention weight cache
/// is not populated, when `activations`/model tables have fewer entries than
/// `model.num_layers`, when a buffer is too short for the given shapes, or
/// when `num_heads` is not a multiple of a non-zero `num_kv_heads`. All of
/// these checks except GEMM failures run before any adapter is modified.
///
/// # Contract (C-WGPU-TRAIN-007)
/// - Precondition: grad_hidden from lm_head backward is finite
/// - Postcondition: LoRA Q/V adapters updated, grad_hidden propagated to layer 0
#[cfg(feature = "gpu")]
pub fn backward_through_layers(
    device: &GpuDevice,
    grad_hidden: &mut Vec<f32>,
    activations: &[LayerActivations],
    model: &mut super::wgpu_trainer::WgpuModelState,
    seq_len: u32,
    hidden_size: u32,
    intermediate_size: u32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    step: u32,
    lora_alpha: f32,
) -> Result<f32, String> {
    // C-WGPU-LORAPLUS-001: LoRA+ — lr_B = 16 * lr_A (Hayou et al. 2024)
    const LORA_PLUS_LR_B_RATIO: f32 = 16.0;

    // Widen before multiplying so buffer sizes cannot overflow u32.
    let s = seq_len as usize;
    let h = hidden_size as usize;
    let inter = intermediate_size as usize;
    let n_layers = model.num_layers;

    let hd = model.head_dim;
    let nh = model.num_heads;
    let nkv = model.num_kv_heads;
    if nkv == 0 || nh % nkv != 0 {
        return Err(format!(
            "num_heads ({nh}) must be a multiple of a non-zero num_kv_heads ({nkv})"
        ));
    }
    let q_dim = nh * hd;
    let kv_dim = nkv * hd;
    let q_dim_u32 =
        u32::try_from(q_dim).map_err(|_| format!("q_dim {q_dim} does not fit in u32"))?;

    // ---- Validate everything up front so a failure never leaves the adapters half-updated ----
    if activations.len() < n_layers
        || model.lora.len() < n_layers
        || model.ffn_cache.len() < n_layers
        || model.attn_cache.len() < n_layers
    {
        return Err(format!(
            "expected {n_layers} layers, got activations={}, lora={}, ffn_cache={}, attn_cache={}",
            activations.len(),
            model.lora.len(),
            model.ffn_cache.len(),
            model.attn_cache.len()
        ));
    }
    if grad_hidden.len() < s * h {
        return Err(format!(
            "grad_hidden has {} elements, need at least {}",
            grad_hidden.len(),
            s * h
        ));
    }
    for (layer_idx, act) in activations.iter().enumerate().take(n_layers) {
        if model.ffn_cache[layer_idx].is_none() {
            return Err(format!("layer {layer_idx}: FFN weight cache not populated"));
        }
        if model.attn_cache[layer_idx].is_none() {
            return Err(format!("layer {layer_idx}: attention weight cache not populated"));
        }
        let adapters = &model.lora[layer_idx];
        let q_rank = adapters.q.rank as usize;
        let v_rank = adapters.v.rank as usize;
        ensure_len("gate_output", layer_idx, act.gate_output.len(), s * inter)?;
        ensure_len("up_output", layer_idx, act.up_output.len(), s * inter)?;
        ensure_len("silu_gate", layer_idx, act.silu_gate.len(), s * inter)?;
        ensure_len("k", layer_idx, act.k.len(), s * kv_dim)?;
        ensure_len("v", layer_idx, act.v.len(), s * kv_dim)?;
        ensure_len("attn_weights", layer_idx, act.attn_weights.len(), nh * s * s)?;
        if q_rank > 0 || v_rank > 0 {
            ensure_len("attn_input", layer_idx, act.attn_input.len(), s * h)?;
        }
        ensure_len("lora_q_h", layer_idx, act.lora_q_h.len(), s * q_rank)?;
        ensure_len("lora_v_h", layer_idx, act.lora_v_h.len(), s * v_rank)?;
        ensure_len("LoRA Q matrix B", layer_idx, adapters.q.b.len(), q_dim * q_rank)?;
        ensure_len("LoRA V matrix B", layer_idx, adapters.v.b.len(), kv_dim * v_rank)?;
    }

    let mut gnorm_sq = 0.0f32;

    // Backward through layers in reverse order
    for layer_idx in (0..n_layers).rev() {
        let act = &activations[layer_idx];

        let (gate_w, up_w, down_w) = model.ffn_cache[layer_idx]
            .as_ref()
            .map(|(g, u, d)| (g.as_slice(), u.as_slice(), d.as_slice()))
            .ok_or_else(|| format!("layer {layer_idx}: FFN weight cache not populated"))?;

        // --- FFN backward ---
        // 1. Down backward: grad_swiglu = grad_hidden @ down (GPU GEMM)
        let mut grad_swiglu = vec![0.0f32; s * inter];
        device.gemm_backward_a(
            grad_hidden,
            down_w,
            &mut grad_swiglu,
            seq_len,
            intermediate_size,
            hidden_size,
        )?;

        // 2. SiLU backward
        let (grad_gate, grad_up) =
            swiglu_backward(&grad_swiglu, &act.gate_output, &act.up_output, &act.silu_gate);

        // 3. Gate backward: grad_input_gate = grad_gate @ gate^T (GPU GEMM)
        let mut grad_input_gate = vec![0.0f32; s * h];
        device.gemm_backward_a(
            &grad_gate,
            gate_w,
            &mut grad_input_gate,
            seq_len,
            hidden_size,
            intermediate_size,
        )?;

        // 4. Up backward: grad_input_up = grad_up @ up^T (GPU GEMM)
        let mut grad_input_up = vec![0.0f32; s * h];
        device.gemm_backward_a(
            &grad_up,
            up_w,
            &mut grad_input_up,
            seq_len,
            hidden_size,
            intermediate_size,
        )?;

        // 5. Sum: grad_ffn_input = grad_input_gate + grad_input_up
        // NOTE(review): this overwrites grad_hidden, so any gradient flowing through a
        // residual connection around the FFN (and through the RMSNorm) is not carried
        // along — confirm against the forward pass whether `+=` is intended here.
        for ((g, gate), up) in grad_hidden[..s * h]
            .iter_mut()
            .zip(&grad_input_gate)
            .zip(&grad_input_up)
        {
            *g = gate + up;
        }

        // --- Attention backward → real LoRA gradients ---
        // NOTE(review): only grad_q / grad_v are derived here; grad_hidden is not
        // updated with the attention branch's input gradient.
        let o_w = model.attn_cache[layer_idx]
            .as_ref()
            .map(|(_, _, _, o)| o.as_slice())
            .ok_or_else(|| format!("layer {layer_idx}: attention weight cache not populated"))?;

        // 1. grad_context = grad_hidden @ O_original. The cache stores O transposed
        //    ([q_dim, h]) and gemm_backward_a computes A @ B^T, which yields exactly that.
        let mut grad_context = vec![0.0f32; s * q_dim];
        device.gemm_backward_a(
            grad_hidden,
            o_w,
            &mut grad_context,
            seq_len,
            q_dim_u32,
            hidden_size,
        )?;

        // 2. Attention backward: grad_q, grad_v from grad_context through softmax+V
        let (grad_q, grad_v) = attention_backward_qv(
            &grad_context,
            &act.k,
            &act.v,
            &act.attn_weights,
            s,
            nh,
            nkv,
            hd,
        );

        // 3. LoRA Q backward + AdamW
        let q_rank = model.lora[layer_idx].q.rank as usize;
        if q_rank > 0 {
            let lq = &mut model.lora[layer_idx].q;
            let scaling = lora_alpha / q_rank as f32;
            let (grad_a, grad_b) = lora_grads(
                &grad_q,
                &act.lora_q_h,
                &act.attn_input,
                &lq.b,
                s,
                q_dim,
                q_rank,
                h,
                scaling,
            );
            // Both factors count: with a zero-initialised B the A gradient is exactly zero.
            gnorm_sq += sq_norm(&grad_a) + sq_norm(&grad_b);
            cpu_adamw(
                &mut lq.a,
                &grad_a,
                &mut lq.m_a,
                &mut lq.v_a,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
                step,
            );
            cpu_adamw(
                &mut lq.b,
                &grad_b,
                &mut lq.m_b,
                &mut lq.v_b,
                lr * LORA_PLUS_LR_B_RATIO,
                beta1,
                beta2,
                eps,
                weight_decay,
                step,
            );
        }

        // 4. LoRA V backward: same pattern with grad_v
        let v_rank = model.lora[layer_idx].v.rank as usize;
        if v_rank > 0 {
            let lv = &mut model.lora[layer_idx].v;
            let scaling = lora_alpha / v_rank as f32;
            let (grad_a, grad_b) = lora_grads(
                &grad_v,
                &act.lora_v_h,
                &act.attn_input,
                &lv.b,
                s,
                kv_dim,
                v_rank,
                h,
                scaling,
            );
            gnorm_sq += sq_norm(&grad_a) + sq_norm(&grad_b);
            cpu_adamw(
                &mut lv.a,
                &grad_a,
                &mut lv.m_a,
                &mut lv.v_a,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
                step,
            );
            cpu_adamw(
                &mut lv.b,
                &grad_b,
                &mut lv.m_b,
                &mut lv.v_b,
                lr * LORA_PLUS_LR_B_RATIO,
                beta1,
                beta2,
                eps,
                weight_decay,
                step,
            );
        }
    }

    Ok(gnorm_sq.sqrt())
}
329
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// FALSIFY: one backward pass over a tiny two-layer model must move the
    /// LoRA A matrices of layer 0 and leave every propagated gradient finite.
    #[test]
    fn test_backward_through_layers_gradient_flow() {
        // Tiny model dimensions
        let lora_rank = 4u32;
        let hidden = 8u32;
        let inter = 16u32;
        let seq = 2u32;
        let layer_count: usize = 2;

        let hidden_sz = hidden as usize;
        let q_dim = 8usize; // 2 heads * 4 head_dim
        let kv_dim = 8usize; // 2 kv_heads * 4 head_dim
        let tokens = seq as usize;
        let seq_hidden = (seq * hidden) as usize;
        let seq_inter = (seq * inter) as usize;

        let device = GpuDevice::new().expect("GPU");

        // Model state with every weight cache already populated with a constant 0.01
        let mut model = super::super::wgpu_trainer::WgpuModelState {
            layers: vec![],
            lora: (0..layer_count)
                .map(|_| {
                    crate::train::transformer_trainer::wgpu_checkpoint::LoraLayerSet::new(
                        lora_rank, hidden, hidden, hidden, inter,
                    )
                })
                .collect(),
            lm_head: vec![0.0f32; 32 * hidden_sz],
            lm_head_m: vec![0.0f32; 32 * hidden_sz],
            lm_head_v: vec![0.0f32; 32 * hidden_sz],
            hidden_size: hidden_sz,
            num_layers: layer_count,
            vocab_size: 32,
            num_heads: 2,
            num_kv_heads: 2,
            head_dim: 4,
            intermediate_size: inter as usize,
            ffn_cache: (0..layer_count)
                .map(|_| {
                    Some((
                        vec![0.01f32; (inter * hidden) as usize],
                        vec![0.01f32; (inter * hidden) as usize],
                        vec![0.01f32; (hidden * inter) as usize],
                    ))
                })
                .collect(),
            attn_cache: (0..layer_count)
                .map(|_| {
                    Some((
                        vec![0.01f32; hidden_sz * 8], // q [h, q_dim]
                        vec![0.01f32; hidden_sz * 8], // k [h, kv_dim]
                        vec![0.01f32; hidden_sz * 8], // v [h, kv_dim]
                        vec![0.01f32; 8 * hidden_sz], // o [q_dim, h]
                    ))
                })
                .collect(),
        };

        // Cached forward activations, identical for every layer
        let ramp = |step: f32| -> Vec<f32> {
            (0..seq_hidden).map(|j| (j as f32 - 8.0) * step).collect()
        };
        let mut activations: Vec<LayerActivations> = Vec::with_capacity(layer_count);
        for _ in 0..layer_count {
            activations.push(LayerActivations {
                attn_input: ramp(0.1),
                hidden_input: ramp(0.1),
                gate_output: vec![0.5f32; seq_inter],
                up_output: vec![0.3f32; seq_inter],
                silu_gate: vec![0.25f32; seq_inter],
                q: vec![0.1f32; tokens * q_dim],
                k: vec![0.1f32; tokens * kv_dim],
                v: vec![0.1f32; tokens * kv_dim],
                attn_weights: vec![0.5f32; 2 * tokens * tokens], // 2 heads
                context: vec![0.1f32; tokens * q_dim],
                lora_q_h: vec![0.01f32; tokens * lora_rank as usize],
                lora_v_h: vec![0.01f32; tokens * lora_rank as usize],
            });
        }

        let mut grad_hidden: Vec<f32> = ramp(0.01);

        // Snapshot layer-0 adapters so the update can be detected afterwards
        let q_a_before = model.lora[0].q.a.clone();
        let v_a_before = model.lora[0].v.a.clone();

        let outcome = backward_through_layers(
            &device,
            &mut grad_hidden,
            &activations,
            &mut model,
            seq,
            hidden,
            inter,
            1e-3,
            0.9,
            0.999,
            1e-8,
            0.01,
            1,
            32.0,
        );
        let gnorm = outcome.expect("backward");

        // LoRA weights must have changed
        assert_ne!(model.lora[0].q.a, q_a_before, "LoRA Q adapter A must be updated");
        assert_ne!(model.lora[0].v.a, v_a_before, "LoRA V adapter A must be updated");
        assert!(gnorm >= 0.0, "Gradient norm must be non-negative");
        assert!(grad_hidden.iter().all(|g| g.is_finite()), "All gradients finite");

        eprintln!("Backward through {} layers: lora_gnorm={:.6}", layer_count, gnorm);
    }
}