1#[cfg(feature = "gpu")]
9use trueno::backends::gpu::GpuDevice;
10
11#[cfg(feature = "gpu")]
/// One in-place AdamW step (decoupled weight decay) on `params`, computed on the CPU.
///
/// For every index `j < params.len()`, with `t = step`:
/// ```text
/// m[j] = β1·m[j] + (1 − β1)·g[j]
/// v[j] = β2·v[j] + (1 − β2)·g[j]²
/// p[j] -= lr · ( m̂ / (√v̂ + eps) + wd·p[j] ),   m̂ = m[j] / (1 − β1^t),  v̂ = v[j] / (1 − β2^t)
/// ```
/// `m` and `v` are the first/second-moment buffers and are updated in place. Elements of
/// `grad`, `m`, `v` beyond `params.len()` are left untouched.
///
/// # Panics
/// * if `step == 0`: the bias correction `1 / (1 − β^0)` would divide by zero and silently
///   turn every parameter into NaN/∞;
/// * if `grad`, `m` or `v` is shorter than `params`.
#[allow(clippy::too_many_arguments)] // flat hyper-parameter list mirrors the AdamW formula
fn cpu_adamw(
    params: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    wd: f32,
    step: u32,
) {
    assert!(
        step >= 1,
        "cpu_adamw: step must be >= 1 (bias correction divides by 1 - beta^step)"
    );
    let n = params.len();
    assert!(
        grad.len() >= n && m.len() >= n && v.len() >= n,
        "cpu_adamw: grad/m/v lengths ({}, {}, {}) shorter than params ({n})",
        grad.len(),
        m.len(),
        v.len()
    );

    // Saturate rather than wrap: `step as i32` turns steps above i32::MAX into a negative
    // exponent. By then β^step has long underflowed to 0, so i32::MAX gives the right limit.
    let t = step.min(i32::MAX as u32) as i32;
    let bc1 = 1.0 / (1.0 - beta1.powi(t));
    let bc2 = 1.0 / (1.0 - beta2.powi(t));

    for (((p, &g), m_j), v_j) in params
        .iter_mut()
        .zip(grad)
        .zip(m.iter_mut())
        .zip(v.iter_mut())
    {
        *m_j = beta1 * *m_j + (1.0 - beta1) * g;
        *v_j = beta2 * *v_j + (1.0 - beta2) * g * g;
        let m_hat = *m_j * bc1;
        let v_hat = *v_j * bc2;
        // Decoupled weight decay: the decay term uses the pre-update parameter value.
        *p -= lr * (m_hat / (v_hat.sqrt() + eps) + wd * *p);
    }
}
36
37#[cfg(feature = "gpu")]
/// Forward-pass activations of one transformer layer, kept for `backward_through_layers`.
///
/// Shapes are the ones `backward_through_layers` indexes with (row-major; `seq` = sequence
/// length). Fields that function does not read are marked as such, and their shapes are
/// assumptions to confirm against the forward pass.
#[derive(Debug, Clone, Default)]
pub struct LayerActivations {
    /// Input to the attention projections, `seq × hidden`; also the `x` fed to the Q/V LoRA
    /// adapters when computing their A-gradients.
    pub attn_input: Vec<f32>,
    /// Not read by `backward_through_layers`. Presumably the layer's incoming hidden state
    /// (`seq × hidden`); TODO confirm.
    pub hidden_input: Vec<f32>,
    /// Gate-projection output before SiLU, `seq × intermediate` (the SiLU derivative is taken
    /// at these values).
    pub gate_output: Vec<f32>,
    /// Up-projection output, `seq × intermediate`.
    pub up_output: Vec<f32>,
    /// Gate after activation, `seq × intermediate`. Presumably `silu(gate_output)`; the backward
    /// uses it as the factor multiplying the up branch.
    pub silu_gate: Vec<f32>,
    /// Query projections. Not read by `backward_through_layers`; presumably
    /// `seq × (num_heads·head_dim)`.
    pub q: Vec<f32>,
    /// Key projections, `seq × (num_kv_heads·head_dim)`.
    pub k: Vec<f32>,
    /// Value projections, `seq × (num_kv_heads·head_dim)`.
    pub v: Vec<f32>,
    /// Softmax attention probabilities, `num_heads × seq(query) × seq(key)`.
    pub attn_weights: Vec<f32>,
    /// Not read by `backward_through_layers`. Presumably the attention output before the O
    /// projection (`seq × (num_heads·head_dim)`).
    pub context: Vec<f32>,
    /// Intermediate of the Q LoRA adapter, `seq × rank` (presumably `attn_input · Aᵀ`).
    pub lora_q_h: Vec<f32>,
    /// Intermediate of the V LoRA adapter, `seq × rank` (presumably `attn_input · Aᵀ`).
    pub lora_v_h: Vec<f32>,
}
65
66#[cfg(feature = "gpu")]
75pub fn backward_through_layers(
76 device: &GpuDevice,
77 grad_hidden: &mut Vec<f32>,
78 activations: &[LayerActivations],
79 model: &mut super::wgpu_trainer::WgpuModelState,
80 seq_len: u32,
81 hidden_size: u32,
82 intermediate_size: u32,
83 lr: f32,
84 beta1: f32,
85 beta2: f32,
86 eps: f32,
87 weight_decay: f32,
88 step: u32,
89 lora_alpha: f32,
90) -> Result<f32, String> {
91 let s = seq_len;
92 let h = hidden_size;
93 let i = intermediate_size;
94 let n_layers = model.num_layers;
95 let mut total_lora_gnorm = 0.0f32;
96
97 for layer_idx in (0..n_layers).rev() {
99 let act = &activations[layer_idx];
100
101 let (gate_w, up_w, down_w) = model.ffn_cache[layer_idx]
103 .as_ref()
104 .map(|(g, u, d)| (g.as_slice(), u.as_slice(), d.as_slice()))
105 .expect("cache populated");
106
107 let mut grad_swiglu = vec![0.0f32; (s * i) as usize];
110 device.gemm_backward_a(grad_hidden, down_w, &mut grad_swiglu, s, i, h)?;
111
112 let n_inter = (s * i) as usize;
114 let mut grad_gate = vec![0.0f32; n_inter];
115 let mut grad_up = vec![0.0f32; n_inter];
116 for j in 0..n_inter {
117 let x = act.gate_output[j];
118 let sig = 1.0 / (1.0 + (-x).exp());
119 let y = x * sig;
120 let silu_prime = sig * (1.0 + x - y);
121 grad_gate[j] = grad_swiglu[j] * act.up_output[j] * silu_prime;
122 grad_up[j] = grad_swiglu[j] * act.silu_gate[j];
123 }
124
125 let mut grad_input_gate = vec![0.0f32; (s * h) as usize];
127 device.gemm_backward_a(&grad_gate, gate_w, &mut grad_input_gate, s, h, i)?;
128
129 let mut grad_input_up = vec![0.0f32; (s * h) as usize];
131 device.gemm_backward_a(&grad_up, up_w, &mut grad_input_up, s, h, i)?;
132
133 for j in 0..(s * h) as usize {
135 grad_hidden[j] = grad_input_gate[j] + grad_input_up[j];
136 }
137
138 let q_dim = model.num_heads * model.head_dim;
141 let kv_dim = model.num_kv_heads * model.head_dim;
142 let hd = model.head_dim;
143 let nh = model.num_heads;
144 let nkv = model.num_kv_heads;
145 let heads_per_kv = nh / nkv;
146 let (_, _, _, o_w) = model.attn_cache[layer_idx]
147 .as_ref()
148 .map(|(q, k, v, o)| (q.as_slice(), k.as_slice(), v.as_slice(), o.as_slice()))
149 .expect("attn cache");
150 let mut grad_context = vec![0.0f32; s as usize * q_dim];
154 device.gemm_backward_a(grad_hidden, o_w, &mut grad_context, s, q_dim as u32, h)?;
155
156 let scale = 1.0 / (hd as f32).sqrt();
158 let mut grad_q = vec![0.0f32; s as usize * q_dim];
159 let mut grad_v = vec![0.0f32; s as usize * kv_dim];
160 for head in 0..nh {
161 let kv_head = head / heads_per_kv;
162 for qi in 0..s as usize {
163 let aw_off = head * s as usize * s as usize + qi * s as usize;
164 let mut grad_scores = vec![0.0f32; s as usize];
166 let mut dot_sum = 0.0f32;
167 for ki in 0..s as usize {
168 for d in 0..hd {
169 grad_scores[ki] += grad_context[qi * q_dim + head * hd + d]
170 * act.v[ki * kv_dim + kv_head * hd + d];
171 }
172 dot_sum += act.attn_weights[aw_off + ki] * grad_scores[ki];
173 }
174 for ki in 0..s as usize {
176 let g_pre = act.attn_weights[aw_off + ki] * (grad_scores[ki] - dot_sum) * scale;
177 for d in 0..hd {
179 grad_q[qi * q_dim + head * hd + d] +=
180 g_pre * act.k[ki * kv_dim + kv_head * hd + d];
181 }
182 }
183 for ki in 0..s as usize {
185 let w = act.attn_weights[aw_off + ki];
186 if w > 0.0 {
187 for d in 0..hd {
188 grad_v[ki * kv_dim + kv_head * hd + d] +=
189 w * grad_context[qi * q_dim + head * hd + d];
190 }
191 }
192 }
193 }
194 }
195
196 let rank = model.lora[layer_idx].q.rank as usize;
198 if rank > 0 {
199 let scaling = lora_alpha / rank as f32;
200 let mut grad_b = vec![0.0f32; q_dim * rank];
202 for qi in 0..q_dim {
203 for ri in 0..rank {
204 let mut sum = 0.0f32;
205 for si in 0..s as usize {
206 sum += grad_q[si * q_dim + qi] * act.lora_q_h[si * rank + ri];
207 }
208 grad_b[qi * rank + ri] = sum * scaling;
209 }
210 }
211 let mut grad_h_cached = vec![0.0f32; s as usize * rank];
214 for si in 0..s as usize {
215 for ri in 0..rank {
216 let mut sum = 0.0f32;
217 for qi in 0..q_dim {
218 sum += grad_q[si * q_dim + qi] * model.lora[layer_idx].q.b[qi * rank + ri];
219 }
220 grad_h_cached[si * rank + ri] = sum * scaling;
221 }
222 }
223 let mut grad_a = vec![0.0f32; rank * h as usize];
224 for ri in 0..rank {
225 for hi in 0..h as usize {
226 let mut sum = 0.0f32;
227 for si in 0..s as usize {
228 sum += grad_h_cached[si * rank + ri] * act.attn_input[si * h as usize + hi];
229 }
230 grad_a[ri * h as usize + hi] = sum;
231 }
232 }
233 total_lora_gnorm += grad_a.iter().map(|g| g * g).sum::<f32>();
234 let lq = &mut model.lora[layer_idx].q;
236 cpu_adamw(
237 &mut lq.a,
238 &grad_a,
239 &mut lq.m_a,
240 &mut lq.v_a,
241 lr,
242 beta1,
243 beta2,
244 eps,
245 weight_decay,
246 step,
247 );
248 cpu_adamw(
249 &mut lq.b,
250 &grad_b,
251 &mut lq.m_b,
252 &mut lq.v_b,
253 lr * 16.0,
254 beta1,
255 beta2,
256 eps,
257 weight_decay,
258 step,
259 );
260 }
261
262 let v_rank = model.lora[layer_idx].v.rank as usize;
264 if v_rank > 0 {
265 let scaling = lora_alpha / v_rank as f32;
266 let mut grad_b = vec![0.0f32; kv_dim * v_rank];
267 for vi in 0..kv_dim {
268 for ri in 0..v_rank {
269 let mut sum = 0.0f32;
270 for si in 0..s as usize {
271 sum += grad_v[si * kv_dim + vi] * act.lora_v_h[si * v_rank + ri];
272 }
273 grad_b[vi * v_rank + ri] = sum * scaling;
274 }
275 }
276 let mut grad_h_cached = vec![0.0f32; s as usize * v_rank];
277 for si in 0..s as usize {
278 for ri in 0..v_rank {
279 let mut sum = 0.0f32;
280 for vi in 0..kv_dim {
281 sum +=
282 grad_v[si * kv_dim + vi] * model.lora[layer_idx].v.b[vi * v_rank + ri];
283 }
284 grad_h_cached[si * v_rank + ri] = sum * scaling;
285 }
286 }
287 let mut grad_a = vec![0.0f32; v_rank * h as usize];
288 for ri in 0..v_rank {
289 for hi in 0..h as usize {
290 let mut sum = 0.0f32;
291 for si in 0..s as usize {
292 sum +=
293 grad_h_cached[si * v_rank + ri] * act.attn_input[si * h as usize + hi];
294 }
295 grad_a[ri * h as usize + hi] = sum;
296 }
297 }
298 total_lora_gnorm += grad_a.iter().map(|g| g * g).sum::<f32>();
299 let lv = &mut model.lora[layer_idx].v;
300 cpu_adamw(
301 &mut lv.a,
302 &grad_a,
303 &mut lv.m_a,
304 &mut lv.v_a,
305 lr,
306 beta1,
307 beta2,
308 eps,
309 weight_decay,
310 step,
311 );
312 cpu_adamw(
313 &mut lv.b,
314 &grad_b,
315 &mut lv.m_b,
316 &mut lv.v_b,
317 lr * 16.0,
318 beta1,
319 beta2,
320 eps,
321 weight_decay,
322 step,
323 );
324 }
325 }
326
327 Ok(total_lora_gnorm.sqrt())
328}
329
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Builds a model state whose FFN/attention weight caches are all filled with 0.01.
    fn tiny_model(
        rank: u32,
        h: u32,
        i_size: u32,
        n_layers: usize,
    ) -> super::super::wgpu_trainer::WgpuModelState {
        let hidden = h as usize;
        let ffn_len = (i_size * h) as usize;
        let proj_len = hidden * 8;
        let lm_len = 32 * hidden;

        let mut state = super::super::wgpu_trainer::WgpuModelState {
            layers: vec![],
            lora: (0..n_layers)
                .map(|_| {
                    crate::train::transformer_trainer::wgpu_checkpoint::LoraLayerSet::new(
                        rank, h, h, h, i_size,
                    )
                })
                .collect(),
            lm_head: vec![0.0f32; lm_len],
            lm_head_m: vec![0.0f32; lm_len],
            lm_head_v: vec![0.0f32; lm_len],
            hidden_size: hidden,
            num_layers: n_layers,
            vocab_size: 32,
            num_heads: 2,
            num_kv_heads: 2,
            head_dim: 4,
            intermediate_size: i_size as usize,
            ffn_cache: vec![None; n_layers],
            attn_cache: vec![None; n_layers],
        };

        for layer in 0..n_layers {
            state.ffn_cache[layer] = Some((
                vec![0.01f32; ffn_len],
                vec![0.01f32; ffn_len],
                vec![0.01f32; ffn_len],
            ));
            state.attn_cache[layer] = Some((
                vec![0.01f32; proj_len],
                vec![0.01f32; proj_len],
                vec![0.01f32; proj_len],
                vec![0.01f32; proj_len],
            ));
        }
        state
    }

    /// One identical set of synthetic activations per layer (2 heads, q/kv width 8).
    fn synthetic_activations(
        s: u32,
        h: u32,
        i_size: u32,
        rank: u32,
        n_layers: usize,
    ) -> Vec<LayerActivations> {
        let q_dim = 8usize;
        let kv_dim = 8usize;
        let seq = s as usize;
        let hidden_elems = (s * h) as usize;
        let inter_elems = (s * i_size) as usize;
        let ramp = || -> Vec<f32> { (0..hidden_elems).map(|j| (j as f32 - 8.0) * 0.1).collect() };

        (0..n_layers)
            .map(|_| LayerActivations {
                attn_input: ramp(),
                hidden_input: ramp(),
                gate_output: vec![0.5f32; inter_elems],
                up_output: vec![0.3f32; inter_elems],
                silu_gate: vec![0.25f32; inter_elems],
                q: vec![0.1f32; seq * q_dim],
                k: vec![0.1f32; seq * kv_dim],
                v: vec![0.1f32; seq * kv_dim],
                attn_weights: vec![0.5f32; 2 * seq * seq],
                context: vec![0.1f32; seq * q_dim],
                lora_q_h: vec![0.01f32; seq * rank as usize],
                lora_v_h: vec![0.01f32; seq * rank as usize],
            })
            .collect()
    }

    #[test]
    fn test_backward_through_layers_gradient_flow() {
        let rank = 4u32;
        let h = 8u32;
        let i_size = 16u32;
        let s = 2u32;
        let n_layers = 2;

        let device = GpuDevice::new().expect("GPU");
        let mut model = tiny_model(rank, h, i_size, n_layers);
        let activations = synthetic_activations(s, h, i_size, rank, n_layers);
        let mut grad_hidden: Vec<f32> =
            (0..(s * h) as usize).map(|j| (j as f32 - 8.0) * 0.01).collect();

        let q_a_before = model.lora[0].q.a.clone();
        let v_a_before = model.lora[0].v.a.clone();

        let gnorm = backward_through_layers(
            &device,
            &mut grad_hidden,
            &activations,
            &mut model,
            s,
            h,
            i_size,
            1e-3,
            0.9,
            0.999,
            1e-8,
            0.01,
            1,
            32.0,
        )
        .expect("backward");

        assert_ne!(model.lora[0].q.a, q_a_before, "LoRA Q adapter A must be updated");
        assert_ne!(model.lora[0].v.a, v_a_before, "LoRA V adapter A must be updated");
        assert!(gnorm >= 0.0, "Gradient norm must be non-negative");
        assert!(grad_hidden.iter().all(|g| g.is_finite()), "All gradients finite");

        eprintln!("Backward through {n_layers} layers: lora_gnorm={gnorm:.6}");
    }
}