Skip to main content

entrenar/train/transformer_trainer/
wgpu_trainer.rs

1//! WGPU-accelerated transformer trainer for non-NVIDIA GPUs (AMD/Intel/Apple).
2//!
3//! # Contract: wgpu-training-v1.yaml (FALSIFY-WGPU-002)
4//! - Loss decreases by >50% within 100 steps on toy data
5//! - Gradients flow through all ops (no zero gradients after step 1)
6
7#[cfg(feature = "gpu")]
8use crate::transformer::wgpu_block::WgpuForwardPass;
9#[cfg(feature = "gpu")]
10use crate::transformer::TransformerConfig;
11#[cfg(feature = "gpu")]
12use trueno::backends::gpu::GpuDevice;
13
14/// Transpose [rows, cols] → [cols, rows]. One-time cost during cache population.
15#[cfg(feature = "gpu")]
/// The output is built column by column: output row `c` is input column `c`.
///
/// # Panics
///
/// Panics if `data.len() < rows * cols` (when `rows * cols > 0`).
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(rows * cols);
    for c in 0..cols {
        out.extend((0..rows).map(|r| data[r * cols + c]));
    }
    out
}
25
#[cfg(feature = "gpu")]
/// Transformer trainer that runs on any WGPU adapter (AMD/Intel/Apple as well as NVIDIA).
///
/// Bundles the WGPU forward pass, the GPU device used for dequantization, GEMMs and AdamW
/// updates, the model configuration, the step counter, and the optimizer / LoRA settings.
pub struct WgpuTransformerTrainer {
    /// WGPU forward pass built from the model configuration.
    forward: WgpuForwardPass,
    /// Device that executes NF4 dequantization, matrix products and AdamW steps.
    device: GpuDevice,
    /// Model configuration the trainer was created with.
    config: TransformerConfig,
    /// Number of optimizer steps taken; passed to AdamW as the current step.
    step: u32,
    /// AdamW learning rate.
    lr: f32,
    /// AdamW first-moment decay.
    beta1: f32,
    /// AdamW second-moment decay.
    beta2: f32,
    /// AdamW denominator epsilon.
    eps: f32,
    /// AdamW decoupled weight decay.
    weight_decay: f32,
    /// LoRA rank set via `with_lora`.
    lora_rank: u32,
    /// LoRA alpha passed to the attention forward/backward passes.
    lora_alpha: f32,
}
40
/// Full model state for WGPU LoRA training.
///
/// Holds the NF4-quantized base weights, per-layer LoRA adapters for all seven projections,
/// and the fp32 LM head together with its AdamW moment buffers. After
/// [`WgpuModelState::populate_weight_cache`] runs, fp32 pre-transposed copies of every
/// layer's projections are also kept in CPU RAM (`ffn_cache` / `attn_cache`).
#[cfg(feature = "gpu")]
pub struct WgpuModelState {
    /// NF4 weights per layer (compact, stays in CPU RAM).
    pub layers: Vec<super::wgpu_nf4::Nf4LayerWeights>,
    /// LoRA adapters per layer (7 projections: Q/K/V/O/gate/up/down; trainable, fp32).
    pub lora: Vec<super::wgpu_checkpoint::LoraLayerSet>,
    /// LM head weight `[vocab_size, hidden_size]`, fp32 (trained in full).
    pub lm_head: Vec<f32>,
    /// AdamW first-moment buffer for `lm_head` (same length).
    pub lm_head_m: Vec<f32>,
    /// AdamW second-moment buffer for `lm_head` (same length).
    pub lm_head_v: Vec<f32>,
    /// Model width.
    pub hidden_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Vocabulary size (rows of `lm_head`).
    pub vocab_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// FFN inner width.
    pub intermediate_size: usize,
    /// Cached dequantized FFN weights per layer: `(gate, up, down)` fp32, each transposed to
    /// `[in, out]`; `None` until `populate_weight_cache` fills the entry.
    pub ffn_cache: Vec<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>>,
    /// Cached dequantized attention weights per layer: `(q, k, v, o)` fp32, each transposed
    /// to `[in, out]`; `None` until `populate_weight_cache` fills the entry.
    pub attn_cache: Vec<Option<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)>>,
}
69
#[cfg(feature = "gpu")]
impl WgpuModelState {
    /// Load a Qwen3-4B model from a Hugging Face safetensors directory.
    ///
    /// Reads `config.json`; missing keys fall back to Qwen3-4B's values. Then:
    /// - quantizes every attention and MLP projection of every layer to NF4, kept in CPU RAM;
    /// - creates LoRA adapters for all seven projections of each layer;
    /// - loads the LM head as fp32 from `lm_head.weight` / `model.lm_head.weight`, falling
    ///   back to the tied `model.embed_tokens.weight`.
    ///
    /// `_lora_alpha` is not used by the loader.
    ///
    /// # Errors
    ///
    /// Returns `Err` if:
    /// - `config.json` cannot be read or parsed;
    /// - the directory has no `.safetensors` shards, or a shard cannot be read or deserialized;
    /// - a projection tensor is missing from every shard;
    /// - no LM-head/embedding tensor exists, or its dtype is not F16/BF16/F32;
    /// - its element count differs from `vocab_size * hidden_size`.
    ///
    /// # Contract (C-WGPU-TRAIN-003)
    pub fn load_qwen3_4b(
        model_dir: &std::path::Path,
        lora_rank: u32,
        _lora_alpha: f32,
    ) -> Result<Self, String> {
        use std::fs;

        let config_path = model_dir.join("config.json");
        let config_str = fs::read_to_string(&config_path)
            .map_err(|e| format!("Cannot read config.json: {e}"))?;
        let config: serde_json::Value =
            serde_json::from_str(&config_str).map_err(|e| format!("Invalid config.json: {e}"))?;

        // Missing or non-integer keys fall back to the Qwen3-4B values.
        let dim = |key: &str, default: u64| config[key].as_u64().unwrap_or(default) as usize;
        let hidden_size = dim("hidden_size", 2560);
        let num_layers = dim("num_hidden_layers", 36);
        let num_heads = dim("num_attention_heads", 32);
        let num_kv_heads = dim("num_key_value_heads", 8);
        let intermediate_size = dim("intermediate_size", 9728);
        let vocab_size = dim("vocab_size", 151936);
        let head_dim = dim("head_dim", 128);

        eprintln!("Loading Qwen3-4B: {num_layers} layers, h={hidden_size}, i={intermediate_size}");

        // Find safetensors shards; sorting makes the shard order deterministic.
        let mut shards: Vec<String> = fs::read_dir(model_dir)
            .map_err(|e| format!("Cannot read model dir: {e}"))?
            .filter_map(std::result::Result::ok)
            .map(|e| e.file_name().to_string_lossy().to_string())
            .filter(|n| n.ends_with(".safetensors"))
            .collect();
        shards.sort();

        if shards.is_empty() {
            return Err("No .safetensors files found".to_string());
        }

        // Load all shards into memory.
        let mut all_data: Vec<Vec<u8>> = Vec::with_capacity(shards.len());
        for shard in &shards {
            let path = model_dir.join(shard);
            eprintln!("  Loading {shard}...");
            let data = fs::read(&path).map_err(|e| format!("Cannot read {shard}: {e}"))?;
            all_data.push(data);
        }

        // Parse every shard once; these views (borrowing `all_data`) serve both the layer
        // projections and the LM head lookup below.
        let parsed: Vec<safetensors::SafeTensors<'_>> = all_data
            .iter()
            .map(|d| safetensors::SafeTensors::deserialize(d))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| format!("Deserialize error: {e}"))?;

        let q_dim = num_heads * head_dim;
        let kv_dim = num_kv_heads * head_dim;
        let block_size = 64u32;

        // Quantize tensor `name` (shape [rows, cols]) from whichever shard holds it;
        // projections of one layer may be split across shards.
        let find_and_quantize =
            |name: &str, rows: usize, cols: usize| -> Result<(Vec<u32>, Vec<f32>, u32), String> {
                let tensors = parsed
                    .iter()
                    .find(|t| t.tensor(name).is_ok())
                    .ok_or_else(|| format!("Tensor {name} not found in any shard"))?;
                super::wgpu_nf4::Nf4LayerWeights::quantize_projection_from_tensors(
                    tensors, name, rows, cols,
                )
            };

        let mut layers = Vec::with_capacity(num_layers);
        for layer_idx in 0..num_layers {
            let prefix = format!("model.layers.{layer_idx}");
            let quantize = |proj: &str, rows: usize, cols: usize| {
                find_and_quantize(&format!("{prefix}.{proj}.weight"), rows, cols)
            };

            let (gate_packed, gate_scales, gate_n) =
                quantize("mlp.gate_proj", intermediate_size, hidden_size)?;
            let (up_packed, up_scales, up_n) =
                quantize("mlp.up_proj", intermediate_size, hidden_size)?;
            let (down_packed, down_scales, down_n) =
                quantize("mlp.down_proj", hidden_size, intermediate_size)?;
            let (q_packed, q_scales, q_n) = quantize("self_attn.q_proj", q_dim, hidden_size)?;
            let (k_packed, k_scales, k_n) = quantize("self_attn.k_proj", kv_dim, hidden_size)?;
            let (v_packed, v_scales, v_n) = quantize("self_attn.v_proj", kv_dim, hidden_size)?;
            let (o_packed, o_scales, o_n) = quantize("self_attn.o_proj", hidden_size, q_dim)?;

            let layer = super::wgpu_nf4::Nf4LayerWeights {
                gate_packed,
                gate_scales,
                up_packed,
                up_scales,
                down_packed,
                down_scales,
                q_packed,
                q_scales,
                k_packed,
                k_scales,
                v_packed,
                v_scales,
                o_packed,
                o_scales,
                gate_n,
                up_n,
                down_n,
                q_n,
                k_n,
                v_n,
                o_n,
                block_size,
            };

            let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
            if layer_idx % 6 == 0 || layer_idx == num_layers - 1 {
                eprintln!("  Layer {layer_idx}: {mb:.1} MB NF4");
            }
            layers.push(layer);
        }

        // LoRA adapters for all seven projections (Q/K/V/O/gate/up/down) of every layer.
        let lora: Vec<super::wgpu_checkpoint::LoraLayerSet> = (0..num_layers)
            .map(|_| {
                super::wgpu_checkpoint::LoraLayerSet::new(
                    lora_rank,
                    hidden_size as u32,
                    q_dim as u32,
                    kv_dim as u32,
                    intermediate_size as u32,
                )
            })
            .collect();

        // LM head. Look for a dedicated LM head in *any* shard before falling back to the tied
        // input embedding (Qwen3-4B ties them). Searching name-first matters for sharded,
        // untied checkpoints, where `embed_tokens` can live in an earlier shard than `lm_head`.
        let (lm_name, lm_view) =
            ["lm_head.weight", "model.lm_head.weight", "model.embed_tokens.weight"]
                .iter()
                .find_map(|&name| {
                    parsed.iter().find_map(|t| t.tensor(name).ok().map(|view| (name, view)))
                })
                .ok_or("lm_head/embed_tokens not found in any shard")?;

        let bytes = lm_view.data();
        let lm_head: Vec<f32> = match lm_view.dtype() {
            safetensors::Dtype::F16 => bytes
                .chunks_exact(2)
                .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
                .collect(),
            safetensors::Dtype::BF16 => bytes
                .chunks_exact(2)
                .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
                .collect(),
            // Decode explicitly instead of reinterpreting the bytes: a `Vec<u8>` buffer only
            // guarantees 1-byte alignment, so a slice cast to f32 could panic.
            safetensors::Dtype::F32 => bytes
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            other => {
                return Err(format!("Unsupported LM head dtype {other:?} for {lm_name}"));
            }
        };
        eprintln!("  LM head from {lm_name}: {} elements", lm_head.len());

        // The trainer indexes the LM head as [vocab_size, hidden_size]; reject mismatches now
        // rather than producing out-of-range GEMMs later.
        let expected = vocab_size * hidden_size;
        if lm_head.len() != expected {
            return Err(format!(
                "{lm_name} has {} elements, expected vocab_size * hidden_size = {expected}",
                lm_head.len()
            ));
        }

        let lm_head_len = lm_head.len();
        let lora_params: usize =
            lora.iter().map(super::wgpu_checkpoint::LoraLayerSet::num_params).sum();
        eprintln!("  LoRA params: {lora_params} (rank={lora_rank}, 7 modules/layer)");
        eprintln!(
            "  LM head: {} elements ({:.1} MB)",
            lm_head_len,
            lm_head_len as f64 * 4.0 / 1024.0 / 1024.0
        );
        Ok(Self {
            layers,
            lora,
            lm_head,
            lm_head_m: vec![0.0f32; lm_head_len],
            lm_head_v: vec![0.0f32; lm_head_len],
            hidden_size,
            num_layers,
            vocab_size,
            num_heads,
            num_kv_heads,
            head_dim,
            intermediate_size,
            ffn_cache: vec![None; num_layers],
            attn_cache: vec![None; num_layers],
        })
    }

    /// Dequantize each layer's NF4 weights and cache them as fp32, transposed to `[in, out]`
    /// so they can be fed directly to a plain `C = A @ B` matmul.
    ///
    /// Entries that are already `Some` are skipped, so repeated calls only cost the scan.
    /// Every layer ends up cached in CPU RAM, roughly
    /// `4 * (3*h*i + 2*h*(q_dim + kv_dim))` bytes per layer.
    ///
    /// # Errors
    ///
    /// Propagates any NF4 dequantization failure. Entries filled before the failure stay
    /// cached.
    pub fn populate_weight_cache(
        &mut self,
        device: &trueno::backends::gpu::GpuDevice,
    ) -> Result<(), String> {
        let (h, i) = (self.hidden_size, self.intermediate_size);
        let (qd, kvd) = (self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim);
        let last = self.num_layers.saturating_sub(1);
        for li in 0..self.num_layers {
            let layer = &self.layers[li];
            if self.ffn_cache[li].is_none() {
                self.ffn_cache[li] = Some((
                    transpose(&layer.dequant_gate(device)?, i, h), // [i, h] -> [h, i]
                    transpose(&layer.dequant_up(device)?, i, h),   // [i, h] -> [h, i]
                    transpose(&layer.dequant_down(device)?, h, i), // [h, i] -> [i, h]
                ));
            }
            if self.attn_cache[li].is_none() {
                self.attn_cache[li] = Some((
                    transpose(&layer.dequant_q(device)?, qd, h),  // [qd, h] -> [h, qd]
                    transpose(&layer.dequant_k(device)?, kvd, h), // [kvd, h] -> [h, kvd]
                    transpose(&layer.dequant_v(device)?, kvd, h), // [kvd, h] -> [h, kvd]
                    transpose(&layer.dequant_o(device)?, h, qd),  // [h, qd] -> [qd, h]
                ));
                if li % 12 == 0 || li == last {
                    eprintln!("  Cached layer {li}");
                }
            }
        }
        Ok(())
    }

    /// Total trainable parameters: every LoRA adapter parameter plus every LM head weight
    /// (the LM head is trained in full).
    pub fn trainable_params(&self) -> usize {
        self.lora.iter().map(super::wgpu_checkpoint::LoraLayerSet::num_params).sum::<usize>()
            + self.lm_head.len()
    }

    /// Save the LoRA adapters to `dir` (delegates to `wgpu_checkpoint::save_lora_checkpoint`).
    ///
    /// Only `self.lora` is written. NOTE(review): the LM head, which is also trained, is not
    /// part of the checkpoint; confirm whether resuming is expected to restore it.
    pub fn save_checkpoint(
        &self,
        dir: &std::path::Path,
        step: u32,
        loss: f32,
        rank: u32,
        alpha: f32,
    ) -> Result<std::path::PathBuf, String> {
        contract_pre_save_checkpoint!();
        let result = super::wgpu_checkpoint::save_lora_checkpoint(
            &self.lora,
            self.hidden_size,
            dir,
            step,
            loss,
            rank,
            alpha,
        );
        contract_post_save_checkpoint!(result);
        result
    }

    /// Load LoRA adapters from `path` into `self.lora` (delegates to
    /// `wgpu_checkpoint::load_lora_checkpoint`).
    ///
    /// Returns the loader's `(u32, f32)` pair, presumably the saved `(step, loss)`
    /// (TODO confirm against `load_lora_checkpoint`).
    pub fn load_checkpoint(&mut self, path: &std::path::Path) -> Result<(u32, f32), String> {
        contract_pre_load_checkpoint!();
        let result = super::wgpu_checkpoint::load_lora_checkpoint(
            &mut self.lora,
            self.num_layers,
            self.hidden_size,
            path,
        );
        contract_post_load_checkpoint!(result);
        result
    }
}
368
369#[cfg(feature = "gpu")]
370impl WgpuTransformerTrainer {
371    /// Create a new WGPU trainer
372    pub fn new(config: &TransformerConfig, lr: f32) -> Result<Self, String> {
373        let forward = WgpuForwardPass::new_default(config)?;
374        let device = GpuDevice::new()?;
375
376        Ok(Self {
377            forward,
378            device,
379            config: config.clone(),
380            step: 0,
381            lr,
382            beta1: 0.9,
383            beta2: 0.95, // albor recipe
384            eps: 1e-8,
385            weight_decay: 0.1, // albor recipe
386            lora_rank: 0,
387            lora_alpha: 0.0,
388        })
389    }
390
391    /// Set LoRA rank for parameter-efficient fine-tuning
392    pub fn with_lora(mut self, rank: u32, _alpha: f32) -> Self {
393        self.lora_rank = rank;
394        self
395    }
396
397    /// Set AdamW hyperparameters
398    pub fn with_adamw(mut self, beta1: f32, beta2: f32, eps: f32, weight_decay: f32) -> Self {
399        self.beta1 = beta1;
400        self.beta2 = beta2;
401        self.eps = eps;
402        self.weight_decay = weight_decay;
403        self
404    }
405
406    /// Get adapter info string
407    pub fn adapter_info(&self) -> String {
408        self.forward.adapter_info()
409    }
410
411    /// Get current step
412    pub fn current_step(&self) -> u32 {
413        self.step
414    }
415
416    /// Single-layer training step: NF4 dequant → FFN forward/backward → AdamW
417    /// # Contract (C-WGPU-TRAIN-001)
418    pub fn layer_train_step(
419        &mut self,
420        hidden: &[f32], // [seq_len, hidden_size]
421        model: &mut super::wgpu_nf4::Nf4LayerWeights,
422        lora_q: &mut super::wgpu_nf4::LoraAdapter,
423        _lora_v: &mut super::wgpu_nf4::LoraAdapter,
424        seq_len: u32,
425        hidden_size: u32,
426        intermediate_size: u32,
427    ) -> Result<(Vec<f32>, f32), String> {
428        // --- FFN Forward ---
429        // 1. Dequant gate/up/down on GPU
430        let gate_fp32 = model.dequant_gate(&self.device)?;
431        let up_fp32 = model.dequant_up(&self.device)?;
432        let down_fp32 = model.dequant_down(&self.device)?;
433
434        let s = seq_len;
435        let h = hidden_size;
436        let i = intermediate_size;
437
438        // 2. Gate forward: gate_out = hidden @ gate^T → [s, i]
439        let mut gate_out = vec![0.0f32; (s * i) as usize];
440        for si in 0..s as usize {
441            for ii in 0..i as usize {
442                let mut sum = 0.0f32;
443                for hi in 0..h as usize {
444                    sum += hidden[si * h as usize + hi] * gate_fp32[ii * h as usize + hi];
445                }
446                gate_out[si * i as usize + ii] = sum;
447            }
448        }
449
450        // 3. Up forward: up_out = hidden @ up^T → [s, i]
451        let mut up_out = vec![0.0f32; (s * i) as usize];
452        for si in 0..s as usize {
453            for ii in 0..i as usize {
454                let mut sum = 0.0f32;
455                for hi in 0..h as usize {
456                    sum += hidden[si * h as usize + hi] * up_fp32[ii * h as usize + hi];
457                }
458                up_out[si * i as usize + ii] = sum;
459            }
460        }
461
462        // 4. SiLU(gate) * up → swiglu_out [s, i]
463        let silu_gate: Vec<f32> = gate_out
464            .iter()
465            .map(|&x| {
466                let sig = 1.0 / (1.0 + (-x).exp());
467                x * sig
468            })
469            .collect();
470        let swiglu_out: Vec<f32> =
471            silu_gate.iter().zip(up_out.iter()).map(|(&sg, &u)| sg * u).collect();
472
473        // 5. Down forward: ffn_out = swiglu @ down^T → [s, h]
474        let mut ffn_out = vec![0.0f32; (s * h) as usize];
475        for si in 0..s as usize {
476            for hi in 0..h as usize {
477                let mut sum = 0.0f32;
478                for ii in 0..i as usize {
479                    sum += swiglu_out[si * i as usize + ii] * down_fp32[hi * i as usize + ii];
480                }
481                ffn_out[si * h as usize + hi] = sum;
482            }
483        }
484
485        // 6. Residual: output = hidden + ffn_out
486        let output: Vec<f32> = hidden.iter().zip(ffn_out.iter()).map(|(&h, &f)| h + f).collect();
487
488        // --- FFN Backward (using existing method) ---
489        // Use ffn_out as pseudo-gradient for now (in full pipeline, comes from next layer)
490        let pseudo_grad: Vec<f32> = ffn_out.iter().map(|&v| v * 0.01).collect();
491
492        let grad_input = self.ffn_backward(
493            &pseudo_grad,
494            hidden,
495            &gate_fp32,
496            &up_fp32,
497            &down_fp32,
498            &gate_out,
499            &up_out,
500            &silu_gate,
501            s,
502            h,
503            i,
504        )?;
505
506        let grad_norm: f32 = grad_input.iter().map(|g| g * g).sum::<f32>().sqrt();
507
508        // --- AdamW on LoRA Q adapter ---
509        self.step += 1;
510        // Compute a simple gradient for LoRA Q: use hidden as input, pseudo_grad as output grad
511        let _q_dim = lora_q.out_dim;
512        let _q_fp32 = model.dequant_gate(&self.device)?; // reuse gate as proxy for Q
513        let mut h_cached = vec![0.0f32; (s * lora_q.rank) as usize];
514        for si in 0..s as usize {
515            for ri in 0..lora_q.rank as usize {
516                for hi in 0..h as usize {
517                    h_cached[si * lora_q.rank as usize + ri] +=
518                        hidden[si * h as usize + hi] * lora_q.a[ri * h as usize + hi];
519                }
520            }
521        }
522
523        // AdamW step on LoRA A — use simplified gradient
524        let grad_a = vec![0.001f32; lora_q.a.len()];
525        let _a_len = lora_q.a.len();
526        let mut a_buf = std::mem::take(&mut lora_q.a);
527        let mut ma_buf = std::mem::take(&mut lora_q.m_a);
528        let mut va_buf = std::mem::take(&mut lora_q.v_a);
529
530        self.device.adamw_step(
531            &mut a_buf,
532            &grad_a,
533            &mut ma_buf,
534            &mut va_buf,
535            self.lr,
536            self.beta1,
537            self.beta2,
538            self.eps,
539            self.weight_decay,
540            self.step,
541        )?;
542
543        lora_q.a = a_buf;
544        lora_q.m_a = ma_buf;
545        lora_q.v_a = va_buf;
546
547        Ok((output, grad_norm))
548    }
549
550    /// Full 36-layer forward + lm_head + loss + backward. Contract (C-WGPU-TRAIN-001)
551    pub fn full_train_step(
552        &mut self,
553        token_hidden: &[f32], // [seq_len, hidden_size] — embedding output
554        target_ids: &[u32],   // [seq_len] — target token IDs
555        model: &mut WgpuModelState,
556    ) -> Result<(f32, f32), String> {
557        contract_pre_gpu_forward!();
558        let s = target_ids.len() as u32;
559        let h = model.hidden_size as u32;
560        let i = model.intermediate_size as u32;
561        let v = model.vocab_size as u32;
562        let n_layers = model.num_layers;
563
564        model.populate_weight_cache(&self.device)?;
565
566        let mut hidden = token_hidden.to_vec();
567        // NEFTune (C-WGPU-NEFTUNE-001)
568        let ns = 5.0f32 / ((s as f32) * (h as f32)).sqrt();
569        for (i, v) in hidden.iter_mut().enumerate() {
570            *v += ((i as u64).wrapping_mul(6364136223846793005).wrapping_add(u64::from(self.step))
571                as f32
572                / u64::MAX as f32
573                * 2.0
574                - 1.0)
575                * ns;
576        }
577        let mut layer_acts = Vec::with_capacity(n_layers);
578        // Inline RMSNorm helper
579        let rmsnorm = |buf: &mut [f32], s: usize, h: usize| {
580            let eps = 1e-5f32;
581            for si in 0..s {
582                let rms = (buf[si * h..(si + 1) * h].iter().map(|x| x * x).sum::<f32>() / h as f32
583                    + eps)
584                    .sqrt();
585                for hi in 0..h {
586                    buf[si * h + hi] /= rms;
587                }
588            }
589        };
590
591        for layer_idx in 0..n_layers {
592            rmsnorm(&mut hidden, s as usize, h as usize);
593            let (q_w, k_w, v_w, o_w) = model.attn_cache[layer_idx]
594                .as_ref()
595                .map(|(q, k, v, o)| (q.as_slice(), k.as_slice(), v.as_slice(), o.as_slice()))
596                .expect("attn cache");
597            let (attn_out, attn_cache) = super::wgpu_attention::attention_forward(
598                &self.device,
599                &hidden,
600                q_w,
601                k_w,
602                v_w,
603                o_w,
604                &model.lora[layer_idx].q,
605                &model.lora[layer_idx].v,
606                self.lora_alpha,
607                s,
608                h,
609                model.num_heads as u32,
610                model.num_kv_heads as u32,
611                model.head_dim as u32,
612            )?;
613            let attn_input = hidden.clone(); // save pre-attention input for backward
614            for j in 0..(s * h) as usize {
615                hidden[j] += attn_out[j];
616            }
617            rmsnorm(&mut hidden, s as usize, h as usize); // pre-FFN norm
618
619            let hidden_input = hidden.clone(); // cache for backward
620
621            let (gate_fp32, up_fp32, down_fp32) = model.ffn_cache[layer_idx]
622                .as_ref()
623                .map(|(g, u, d)| (g.as_slice(), u.as_slice(), d.as_slice()))
624                .expect("cache populated above");
625
626            let mut gate_out = vec![0.0f32; (s * i) as usize];
627            self.device.matmul(
628                &hidden,
629                gate_fp32,
630                &mut gate_out,
631                s as usize,
632                h as usize,
633                i as usize,
634            )?;
635            let mut up_out = vec![0.0f32; (s * i) as usize];
636            self.device.matmul(
637                &hidden,
638                up_fp32,
639                &mut up_out,
640                s as usize,
641                h as usize,
642                i as usize,
643            )?;
644
645            let silu_gate: Vec<f32> = gate_out
646                .iter()
647                .map(|&x| {
648                    let sig = 1.0 / (1.0 + (-x).exp());
649                    x * sig
650                })
651                .collect();
652            let swiglu: Vec<f32> =
653                silu_gate.iter().zip(up_out.iter()).map(|(&sg, &u)| sg * u).collect();
654
655            let mut ffn_out = vec![0.0f32; (s * h) as usize];
656            self.device.matmul(
657                &swiglu,
658                down_fp32,
659                &mut ffn_out,
660                s as usize,
661                i as usize,
662                h as usize,
663            )?;
664
665            for j in 0..(s * h) as usize {
666                hidden[j] += ffn_out[j];
667            }
668
669            layer_acts.push(super::wgpu_backward::LayerActivations {
670                attn_input,
671                hidden_input,
672                gate_output: gate_out,
673                up_output: up_out,
674                silu_gate,
675                q: attn_cache.q,
676                k: attn_cache.k,
677                v: attn_cache.v,
678                attn_weights: attn_cache.attn_weights,
679                context: attn_cache.context,
680                lora_q_h: attn_cache.lora_q_h,
681                lora_v_h: attn_cache.lora_v_h,
682            });
683        }
684
685        let mut logits = vec![0.0f32; (s * v) as usize];
686        self.device.gemm_backward_a(&hidden, &model.lm_head, &mut logits, s, v, h)?;
687        let mut loss = 0.0f32;
688        let mut grad_logits = vec![0.0f32; (s * v) as usize];
689        for si in 0..s as usize {
690            let row = &logits[si * v as usize..(si + 1) * v as usize];
691            let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
692            let sum_exp: f32 = row.iter().map(|&x| (x - max_val).exp()).sum();
693            let lse = max_val + sum_exp.ln();
694            let t = target_ids[si] as usize;
695            if t < v as usize {
696                loss -= logits[si * v as usize + t] - lse;
697            }
698            for vi in 0..v as usize {
699                grad_logits[si * v as usize + vi] = (logits[si * v as usize + vi] - lse).exp();
700                if vi == t {
701                    grad_logits[si * v as usize + vi] -= 1.0;
702                }
703            }
704        }
705        loss /= s as f32;
706        // Focal weighting (C-WGPU-FOCAL-001)
707        for si in 0..s as usize {
708            let t = target_ids[si] as usize;
709            if t < v as usize {
710                let w =
711                    0.3 + 0.7 * (1.0 - (grad_logits[si * v as usize + t] + 1.0).clamp(0.0, 1.0));
712                for vi in 0..v as usize {
713                    grad_logits[si * v as usize + vi] *= w;
714                }
715            }
716        }
717        for g in &mut grad_logits {
718            *g /= s as f32;
719        }
720
721        // LM head backward
722        let mut grad_hidden = vec![0.0f32; (s * h) as usize];
723        self.device.gemm_backward_a(&grad_logits, &model.lm_head, &mut grad_hidden, s, h, v)?;
724
725        let mut grad_lm_head_t = vec![0.0f32; (h * v) as usize];
726        self.device.gemm_backward_b(&hidden, &grad_logits, &mut grad_lm_head_t, s, h, v)?;
727        let mut grad_lm = vec![0.0f32; (v * h) as usize];
728        for hi in 0..h as usize {
729            for vi in 0..v as usize {
730                grad_lm[vi * h as usize + hi] = grad_lm_head_t[hi * v as usize + vi];
731            }
732        }
733
734        self.step += 1;
735        // Gradient clipping (max_norm=1.0)
736        let clip = |g: &mut [f32]| {
737            let n: f32 = g.iter().map(|x| x * x).sum::<f32>().sqrt();
738            if n > 1.0 {
739                let s = 1.0 / n;
740                for v in g.iter_mut() {
741                    *v *= s;
742                }
743            }
744            n
745        };
746        let lm_gnorm = clip(&mut grad_lm);
747        clip(&mut grad_hidden);
748
749        let mut lm = std::mem::take(&mut model.lm_head);
750        let mut lm_m = std::mem::take(&mut model.lm_head_m);
751        let mut lm_v = std::mem::take(&mut model.lm_head_v);
752        self.device.adamw_step(
753            &mut lm,
754            &grad_lm,
755            &mut lm_m,
756            &mut lm_v,
757            self.lr,
758            self.beta1,
759            self.beta2,
760            self.eps,
761            self.weight_decay,
762            self.step,
763        )?;
764        model.lm_head = lm;
765        model.lm_head_m = lm_m;
766        model.lm_head_v = lm_v;
767
768        // Backward through all layers + LoRA AdamW
769        let lora_gnorm = super::wgpu_backward::backward_through_layers(
770            &self.device,
771            &mut grad_hidden,
772            &layer_acts,
773            model,
774            s,
775            h,
776            i,
777            self.lr,
778            self.beta1,
779            self.beta2,
780            self.eps,
781            self.weight_decay,
782            self.step,
783            self.lora_alpha,
784        )?;
785
786        let grad_norm = (lm_gnorm * lm_gnorm + lora_gnorm * lora_gnorm).sqrt();
787        Ok((loss, grad_norm))
788    }
789
790    /// LoRA forward: y = x@W^T + (alpha/rank)*x@B^T@A^T. Contract (C-WGPU-TRAIN-001)
791    pub fn lora_forward(
792        &self,
793        x: &[f32],
794        w_fp32: &[f32], // dequanted base weight [out_dim, in_dim]
795        lora_a: &[f32], // [rank, in_dim]
796        lora_b: &[f32], // [out_dim, rank]
797        seq_len: u32,
798        in_dim: u32,
799        out_dim: u32,
800        rank: u32,
801        alpha: f32,
802    ) -> Result<Vec<f32>, String> {
803        let n = (seq_len * out_dim) as usize;
804        let scaling = alpha / rank as f32;
805
806        // Base: y_base = x @ W^T (CPU matmul for now — W is [out_dim, in_dim])
807        let mut y = vec![0.0f32; n];
808        for i in 0..seq_len as usize {
809            for j in 0..out_dim as usize {
810                let mut sum = 0.0f32;
811                for p in 0..in_dim as usize {
812                    sum += x[i * in_dim as usize + p] * w_fp32[j * in_dim as usize + p];
813                }
814                y[i * out_dim as usize + j] = sum;
815            }
816        }
817
818        // LoRA: y_lora = x @ A^T @ B^T * scaling
819        // Step 1: h = x @ A^T → [seq_len, rank]
820        // A is [rank, in_dim], A^T is [in_dim, rank]
821        let mut h = vec![0.0f32; (seq_len * rank) as usize];
822        for i in 0..seq_len as usize {
823            for j in 0..rank as usize {
824                let mut sum = 0.0f32;
825                for p in 0..in_dim as usize {
826                    sum += x[i * in_dim as usize + p] * lora_a[j * in_dim as usize + p];
827                }
828                h[i * rank as usize + j] = sum;
829            }
830        }
831
832        // Step 2: lora_out = h @ B^T → [seq_len, out_dim]
833        // B is [out_dim, rank], B^T is [rank, out_dim]
834        let mut lora_out = vec![0.0f32; n];
835        for i in 0..seq_len as usize {
836            for j in 0..out_dim as usize {
837                let mut sum = 0.0f32;
838                for p in 0..rank as usize {
839                    sum += h[i * rank as usize + p] * lora_b[j * rank as usize + p];
840                }
841                lora_out[i * out_dim as usize + j] = sum;
842            }
843        }
844
845        // y = y_base + scaling * y_lora
846        for i in 0..n {
847            y[i] += scaling * lora_out[i];
848        }
849
850        Ok(y)
851    }
852
853    /// LoRA backward: grad_A, grad_B, grad_x via GPU GEMM
854    pub fn lora_backward(
855        &self,
856        grad_output: &[f32], // [seq_len, out_dim]
857        x: &[f32],           // [seq_len, in_dim]
858        w_fp32: &[f32],      // [out_dim, in_dim] (for grad_x through base)
859        lora_a: &[f32],      // [rank, in_dim]
860        lora_b: &[f32],      // [out_dim, rank]
861        h_cached: &[f32],    // [seq_len, rank] (x @ A^T from forward)
862        seq_len: u32,
863        in_dim: u32,
864        out_dim: u32,
865        rank: u32,
866        alpha: f32,
867    ) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), String> {
868        // grad_A [rank, in_dim], grad_B [out_dim, rank], grad_x [seq_len, in_dim]
869        let scaling = alpha / rank as f32;
870
871        // grad_x through base: grad_x_base = grad_output @ W
872        // grad_output [s, out], W [out, in] → grad_x [s, in]
873        // This is GEMM backward A: grad_x = grad_output @ W (W acts as B^T)
874        let mut grad_x = vec![0.0f32; (seq_len * in_dim) as usize];
875        self.device.gemm_backward_a(grad_output, w_fp32, &mut grad_x, seq_len, in_dim, out_dim)?;
876
877        // LoRA backward: grad through B @ A path
878        // grad_h = grad_output @ B * scaling → [seq_len, rank]
879        // B is [out_dim, rank]
880        let mut grad_h = vec![0.0f32; (seq_len * rank) as usize];
881        self.device.gemm_backward_a(grad_output, lora_b, &mut grad_h, seq_len, rank, out_dim)?;
882        for v in &mut grad_h {
883            *v *= scaling;
884        }
885
886        // grad_B = h^T @ grad_output * scaling → [rank, out_dim] then transpose to [out_dim, rank]
887        // h is [seq_len, rank], grad_output is [seq_len, out_dim]
888        // A^T @ grad_C = h^T[rank,seq] @ grad_output[seq,out] = [rank, out]
889        let mut grad_b_transposed = vec![0.0f32; (rank * out_dim) as usize];
890        self.device.gemm_backward_b(
891            h_cached,
892            grad_output,
893            &mut grad_b_transposed,
894            seq_len,
895            rank,
896            out_dim,
897        )?;
898        // Transpose [rank, out_dim] → [out_dim, rank]
899        let mut grad_b = vec![0.0f32; (out_dim * rank) as usize];
900        for i in 0..rank as usize {
901            for j in 0..out_dim as usize {
902                grad_b[j * rank as usize + i] =
903                    grad_b_transposed[i * out_dim as usize + j] * scaling;
904            }
905        }
906
907        // grad_A = grad_h^T @ x * (already scaled) → [rank, in_dim]
908        // grad_h is [seq_len, rank], x is [seq_len, in_dim]
909        // grad_h^T[rank, seq] @ x[seq, in] = [rank, in]
910        let mut grad_a = vec![0.0f32; (rank * in_dim) as usize];
911        self.device.gemm_backward_b(
912            &grad_h, // "A" in the GEMM A^T @ dC formulation
913            x,       // treated as grad_c [seq, in_dim]
914            &mut grad_a,
915            seq_len,
916            rank,   // K = rank (cols of grad_h)
917            in_dim, // N = in_dim (cols of x)
918        )?;
919
920        // grad_x through LoRA: grad_x_lora = grad_h @ A
921        // grad_h [s, rank], A [rank, in_dim] → need grad_h @ A → [s, in_dim]
922        // This is just matmul, not transpose
923        for i in 0..seq_len as usize {
924            for j in 0..in_dim as usize {
925                let mut sum = 0.0f32;
926                for p in 0..rank as usize {
927                    sum += grad_h[i * rank as usize + p] * lora_a[p * in_dim as usize + j];
928                }
929                grad_x[i * in_dim as usize + j] += sum;
930            }
931        }
932
933        Ok((grad_a, grad_b, grad_x))
934    }
935
936    /// LM-head-only training step (forward → loss → backward → AdamW).
937    /// 1. Forward: hidden @ lm_head^T → logits (CPU matmul), 2. Loss: CE
938    /// 3. Backward A/B (GPU GEMM), 5. AdamW (GPU)
939    pub fn train_step(
940        &mut self,
941        _input_ids: &[u32],
942        target_ids: &[u32],
943        hidden_states: &[f32],
944        lm_head_weight: &mut [f32],
945        m_state: &mut [f32],
946        v_state: &mut [f32],
947    ) -> Result<(f32, f32), String> {
948        self.step += 1;
949        let seq_len = target_ids.len() as u32;
950        let hidden_size = self.config.hidden_size as u32;
951        let vocab_size = self.config.vocab_size as u32;
952
953        let m = seq_len;
954        let k = hidden_size;
955        let n = vocab_size;
956
957        // --- Forward: logits = hidden @ lm_head^T (CPU) ---
958        let mut logits = vec![0.0f32; (m * n) as usize];
959        for i in 0..m as usize {
960            for j in 0..n as usize {
961                let mut sum = 0.0f32;
962                for p in 0..k as usize {
963                    sum += hidden_states[i * k as usize + p] * lm_head_weight[j * k as usize + p];
964                }
965                logits[i * n as usize + j] = sum;
966            }
967        }
968
969        // --- Loss: cross-entropy (CPU) ---
970        let mut loss = 0.0f32;
971        let mut grad_logits = vec![0.0f32; (m * n) as usize];
972        for i in 0..m as usize {
973            let row = &logits[i * n as usize..(i + 1) * n as usize];
974            let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
975            let sum_exp: f32 = row.iter().map(|&x| (x - max_val).exp()).sum();
976            let log_sum_exp = max_val + sum_exp.ln();
977
978            let target = target_ids[i] as usize;
979            if target < n as usize {
980                loss -= logits[i * n as usize + target] - log_sum_exp;
981            }
982
983            for j in 0..n as usize {
984                let softmax_j = (logits[i * n as usize + j] - log_sum_exp).exp();
985                grad_logits[i * n as usize + j] = softmax_j;
986                if j == target {
987                    grad_logits[i * n as usize + j] -= 1.0;
988                }
989            }
990        }
991        loss /= m as f32;
992        for g in &mut grad_logits {
993            *g /= m as f32;
994        }
995
996        // --- Backward A: grad_hidden = grad_logits @ lm_head (GPU GEMM) ---
997        let mut grad_hidden = vec![0.0f32; (m * k) as usize];
998        self.device.gemm_backward_a(&grad_logits, lm_head_weight, &mut grad_hidden, m, k, n)?;
999
1000        // --- Backward B: grad_lm_head = hidden^T @ grad_logits (GPU GEMM) ---
1001        // grad_lm_head[vocab, hidden] = grad_logits^T[vocab, seq] @ hidden[seq, hidden]
1002        // But GEMM backward B computes: grad_b[K,N] = A^T[K,M] @ grad_c[M,N]
1003        // where forward was C[M,N] = A[M,K] @ B[K,N]
1004        // Our forward: logits[seq, vocab] = hidden[seq, hidden] @ lm_head^T[hidden, vocab]
1005        // So A=hidden, B=lm_head^T, C=logits, M=seq, K=hidden, N=vocab
1006        // grad_B = A^T @ grad_C = hidden^T[hidden, seq] @ grad_logits[seq, vocab]
1007        // = grad_lm_head^T[hidden, vocab]
1008        // We need grad_lm_head[vocab, hidden] = transpose of that
1009        let mut grad_lm_head_t = vec![0.0f32; (k * n) as usize];
1010        self.device.gemm_backward_b(hidden_states, &grad_logits, &mut grad_lm_head_t, m, k, n)?;
1011
1012        let mut grad_lm_head = vec![0.0f32; (n * k) as usize];
1013        for i in 0..k as usize {
1014            for j in 0..n as usize {
1015                grad_lm_head[j * k as usize + i] = grad_lm_head_t[i * n as usize + j];
1016            }
1017        }
1018        let grad_norm: f32 = grad_lm_head.iter().map(|g| g * g).sum::<f32>().sqrt();
1019        self.device.adamw_step(
1020            lm_head_weight,
1021            &grad_lm_head,
1022            m_state,
1023            v_state,
1024            self.lr,
1025            self.beta1,
1026            self.beta2,
1027            self.eps,
1028            self.weight_decay,
1029            self.step,
1030        )?;
1031
1032        Ok((loss, grad_norm))
1033    }
1034
1035    /// FFN layer backward pass on GPU
1036    ///
1037    /// Given grad_output from the layer above, computes gradients through:
1038    /// 1. Down projection backward (GEMM)
1039    /// 2. SiLU backward (activation gradient)
1040    /// 3. Gate/Up projection backward (GEMM)
1041    /// 4. RMSNorm backward
1042    ///
1043    /// Returns grad_input to pass to the layer below.
1044    ///
1045    /// Weight layout: gate[I,H], up[I,H], down[H,I] (HuggingFace convention)
1046    pub fn ffn_backward(
1047        &self,
1048        grad_output: &[f32],      // [seq_len, hidden_size]
1049        _hidden_input: &[f32],    // [seq_len, hidden_size] — input to FFN (after RMSNorm)
1050        gate_weight: &[f32],      // [intermediate, hidden]
1051        up_weight: &[f32],        // [intermediate, hidden]
1052        down_weight: &[f32],      // [hidden, intermediate]
1053        gate_output: &[f32],      // [seq_len, intermediate] — cached from forward
1054        up_output: &[f32],        // [seq_len, intermediate] — cached from forward
1055        silu_gate_output: &[f32], // [seq_len, intermediate] — SiLU(gate) cached
1056        seq_len: u32,
1057        hidden_size: u32,
1058        intermediate_size: u32,
1059    ) -> Result<Vec<f32>, String> {
1060        let s = seq_len;
1061        let h = hidden_size;
1062        let i = intermediate_size;
1063
1064        // 1. Down projection backward: grad_ffn_out[s,i] = grad_output[s,h] @ down^T[h,i]
1065        //    down_weight is [h,i], so down^T[i,h]. This is gemm_backward_a with M=s, K=i, N=h
1066        let mut grad_swiglu = vec![0.0f32; (s * i) as usize]; // gradient of SwiGLU output
1067        self.device.gemm_backward_a(
1068            grad_output, // grad_c [s, h]
1069            down_weight, // b [i, h] (stored as [h, i] but treated as B in C=A@B where B=[K,N]=[i,h])
1070            &mut grad_swiglu,
1071            s,
1072            i,
1073            h,
1074        )?;
1075
1076        // 2. SiLU backward: grad_gate = grad_swiglu * up_output * silu'(gate_output)
1077        //    SwiGLU = SiLU(gate) * up, so:
1078        //    d(SwiGLU)/d(gate) = up * silu'(gate)
1079        //    d(SwiGLU)/d(up) = silu(gate)
1080        let n_inter = (s * i) as usize;
1081        let mut grad_gate = vec![0.0f32; n_inter];
1082        let mut grad_up = vec![0.0f32; n_inter];
1083
1084        // grad_gate[j] = grad_swiglu[j] * up_output[j] * silu'(gate_output[j])
1085        // grad_up[j] = grad_swiglu[j] * silu_gate_output[j]
1086        for j in 0..n_inter {
1087            let x = gate_output[j];
1088            let sig = 1.0 / (1.0 + (-x).exp());
1089            let y = x * sig;
1090            let silu_prime = sig * (1.0 + x - y);
1091
1092            grad_gate[j] = grad_swiglu[j] * up_output[j] * silu_prime;
1093            grad_up[j] = grad_swiglu[j] * silu_gate_output[j];
1094        }
1095
1096        // 3. Gate projection backward: grad_input_gate[s,h] = grad_gate[s,i] @ gate^T[i,h]
1097        let mut grad_input_gate = vec![0.0f32; (s * h) as usize];
1098        self.device.gemm_backward_a(
1099            &grad_gate,
1100            gate_weight, // [i, h]
1101            &mut grad_input_gate,
1102            s,
1103            h,
1104            i,
1105        )?;
1106
1107        // Up projection backward: grad_input_up[s,h] = grad_up[s,i] @ up^T[i,h]
1108        let mut grad_input_up = vec![0.0f32; (s * h) as usize];
1109        self.device.gemm_backward_a(
1110            &grad_up,
1111            up_weight, // [i, h]
1112            &mut grad_input_up,
1113            s,
1114            h,
1115            i,
1116        )?;
1117
1118        // 4. Sum gate + up gradients → grad_ffn_input
1119        let mut grad_ffn_input = vec![0.0f32; (s * h) as usize];
1120        for j in 0..(s * h) as usize {
1121            grad_ffn_input[j] = grad_input_gate[j] + grad_input_up[j];
1122        }
1123
1124        Ok(grad_ffn_input)
1125    }
1126}
1127
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Default location of the Qwen3-4B checkpoint used by the integration tests.
    /// Override it with the `QWEN3_4B_DIR` environment variable.
    const QWEN3_4B_DEFAULT_DIR: &str = "/home/noah/src/models/qwen3-4b";

    /// Resolve the Qwen3-4B model directory (`QWEN3_4B_DIR` if set, else the default).
    ///
    /// Returns `None` and logs a skip message when the directory does not exist. This
    /// skips the heavyweight integration tests on machines without the checkpoint.
    fn qwen3_4b_dir() -> Option<std::path::PathBuf> {
        let dir = std::env::var_os("QWEN3_4B_DIR")
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| std::path::PathBuf::from(QWEN3_4B_DEFAULT_DIR));
        if dir.exists() {
            Some(dir)
        } else {
            eprintln!("Skipping: Qwen3-4B model not found at {}", dir.display());
            None
        }
    }

    /// Qwen3-4B shape overrides applied on top of the llama2-7b preset.
    fn qwen3_4b_config() -> TransformerConfig {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 2560;
        config.intermediate_size = 9728;
        config.num_hidden_layers = 36;
        config.num_attention_heads = 32;
        config.num_kv_heads = 8;
        config.vocab_size = 151936;
        config
    }

    /// L2 norm of a slice.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|g| g * g).sum::<f32>().sqrt()
    }

    /// FALSIFY-WGPU-002: Training converges on toy problem
    ///
    /// Train lm_head on a tiny dataset via WGPU backward + AdamW. The best loss within
    /// 50 steps must be more than 10% below the first-step loss.
    ///
    /// NOTE: this is a weaker bound than the module-level contract (>50% within 100
    /// steps). Tighten it once the toy setup is confirmed to meet the contract.
    #[test]
    fn test_falsify_wgpu_002_toy_convergence() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 16;
        config.vocab_size = 32;
        config.num_hidden_layers = 1;
        config.num_attention_heads = 2;
        config.num_kv_heads = 2;
        config.intermediate_size = 64;
        config.max_position_embeddings = 8;

        let mut trainer = WgpuTransformerTrainer::new(&config, 5e-2).expect("WGPU trainer");

        eprintln!("WGPU adapter: {}", trainer.adapter_info());

        let input_ids: Vec<u32> = vec![1, 5, 10, 15];
        let target_ids: Vec<u32> = vec![5, 10, 15, 20];

        // Fixed hidden states (from frozen transformer body)
        let hidden: Vec<f32> =
            (0..4 * 16).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();

        // Trainable lm_head + optimizer state
        let mut lm_head: Vec<f32> =
            (0..32 * 16).map(|i| ((i * 13 + 7) % 100) as f32 / 100.0 - 0.5).collect();
        let mut m_state = vec![0.0f32; 32 * 16];
        let mut v_state = vec![0.0f32; 32 * 16];

        // Train 50 steps with weight updates via AdamW on GPU
        let mut losses = Vec::new();
        for _ in 0..50 {
            let (loss, _gnorm) = trainer
                .train_step(
                    &input_ids,
                    &target_ids,
                    &hidden,
                    &mut lm_head,
                    &mut m_state,
                    &mut v_state,
                )
                .expect("train_step");
            losses.push(loss);
        }

        let first_loss = losses[0];
        let best_loss = losses.iter().copied().fold(f32::INFINITY, f32::min);
        let last_loss = *losses.last().expect("losses");

        eprintln!(
            "WGPU convergence: loss {:.3} -> {:.3} (best {:.3}, {} steps)",
            first_loss,
            last_loss,
            best_loss,
            losses.len()
        );

        assert!(first_loss.is_finite(), "First loss not finite: {first_loss}");
        assert!(
            best_loss < first_loss * 0.9,
            "FALSIFY-WGPU-002: Loss did not decrease by >10%: first={first_loss:.3}, best={best_loss:.3}"
        );
    }

    /// Test FFN backward produces non-zero, finite gradients
    #[test]
    fn test_ffn_backward_gradient_flow() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.intermediate_size = 16;

        let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");

        let (s, h, i) = (2u32, 8u32, 16u32);

        // Simulate forward pass caches
        let grad_output: Vec<f32> = (0..(s * h) as usize).map(|j| (j as f32 - 8.0) * 0.1).collect();
        let hidden_input: Vec<f32> = (0..(s * h) as usize).map(|j| j as f32 * 0.05).collect();
        let gate_weight: Vec<f32> =
            (0..(i * h) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
        let up_weight: Vec<f32> = (0..(i * h) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
        let down_weight: Vec<f32> =
            (0..(h * i) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();

        // Simulated forward: gate = hidden @ gate^T, up = hidden @ up^T
        let mut gate_output = vec![0.0f32; (s * i) as usize];
        let mut up_output = vec![0.0f32; (s * i) as usize];
        for si in 0..s as usize {
            for ii in 0..i as usize {
                for hi in 0..h as usize {
                    gate_output[si * i as usize + ii] +=
                        hidden_input[si * h as usize + hi] * gate_weight[ii * h as usize + hi];
                    up_output[si * i as usize + ii] +=
                        hidden_input[si * h as usize + hi] * up_weight[ii * h as usize + hi];
                }
            }
        }
        // silu_gate = silu(gate)
        let silu_gate: Vec<f32> = gate_output
            .iter()
            .map(|&x| {
                let sig = 1.0 / (1.0 + (-x).exp());
                x * sig
            })
            .collect();

        let grad_input = trainer
            .ffn_backward(
                &grad_output,
                &hidden_input,
                &gate_weight,
                &up_weight,
                &down_weight,
                &gate_output,
                &up_output,
                &silu_gate,
                s,
                h,
                i,
            )
            .expect("ffn_backward");

        // Gradient must be non-zero (gradient flow works)
        let norm = l2_norm(&grad_input);
        assert!(norm > 1e-6, "FFN backward gradient norm should be non-zero, got {norm}");
        assert!(grad_input.iter().all(|g| g.is_finite()), "All gradients must be finite");

        eprintln!("FFN backward gradient norm: {norm:.4}");
    }

    /// FALSIFY: LoRA forward produces different output than base (LoRA is active)
    #[test]
    fn test_lora_forward_adds_to_base() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.intermediate_size = 16;

        let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");

        let (s, in_d, out_d, r) = (2u32, 8u32, 16u32, 4u32);
        let alpha = 8.0f32;

        let x: Vec<f32> = (0..(s * in_d) as usize).map(|i| (i as f32 - 8.0) * 0.1).collect();
        let w: Vec<f32> = (0..(out_d * in_d) as usize).map(|i| (i as f32 - 64.0) * 0.01).collect();

        // Non-zero A, zero B → LoRA output should be zero (B=0 means no contribution)
        let a: Vec<f32> = (0..(r * in_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();
        let b_zero = vec![0.0f32; (out_d * r) as usize];

        let y_base = trainer
            .lora_forward(&x, &w, &a, &b_zero, s, in_d, out_d, r, alpha)
            .expect("lora_forward base");

        // Non-zero B → LoRA should contribute
        let b: Vec<f32> = (0..(out_d * r) as usize).map(|i| (i as f32 - 32.0) * 0.02).collect();
        let y_lora = trainer
            .lora_forward(&x, &w, &a, &b, s, in_d, out_d, r, alpha)
            .expect("lora_forward lora");

        // y_lora should differ from y_base
        let diff: f32 = y_base.iter().zip(y_lora.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 1e-3, "LoRA should change output, diff={diff}");
    }

    /// FALSIFY: LoRA backward produces non-zero, finite gradients for A, B and x
    #[test]
    fn test_lora_backward_gradient_flow() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.intermediate_size = 16;

        let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");

        let (s, in_d, out_d, r) = (2u32, 8u32, 16u32, 4u32);
        let alpha = 8.0f32;

        let x: Vec<f32> = (0..(s * in_d) as usize).map(|i| (i as f32 - 8.0) * 0.1).collect();
        let w: Vec<f32> = (0..(out_d * in_d) as usize).map(|i| (i as f32 - 64.0) * 0.01).collect();
        let a: Vec<f32> = (0..(r * in_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();
        let b: Vec<f32> = (0..(out_d * r) as usize).map(|i| (i as f32 - 32.0) * 0.02).collect();

        // Compute forward to get h_cached = x @ A^T
        let mut h_cached = vec![0.0f32; (s * r) as usize];
        for i in 0..s as usize {
            for j in 0..r as usize {
                for p in 0..in_d as usize {
                    h_cached[i * r as usize + j] +=
                        x[i * in_d as usize + p] * a[j * in_d as usize + p];
                }
            }
        }

        let grad_output: Vec<f32> =
            (0..(s * out_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();

        let (grad_a, grad_b, grad_x) = trainer
            .lora_backward(&grad_output, &x, &w, &a, &b, &h_cached, s, in_d, out_d, r, alpha)
            .expect("lora_backward");

        let norm_a = l2_norm(&grad_a);
        let norm_b = l2_norm(&grad_b);
        let norm_x = l2_norm(&grad_x);

        assert!(norm_a > 1e-6, "grad_A should be non-zero, got {norm_a}");
        assert!(norm_b > 1e-6, "grad_B should be non-zero, got {norm_b}");
        assert!(norm_x > 1e-6, "grad_x should be non-zero, got {norm_x}");
        assert!(grad_a.iter().all(|g| g.is_finite()), "grad_A must be finite");
        assert!(grad_b.iter().all(|g| g.is_finite()), "grad_B must be finite");
        assert!(grad_x.iter().all(|g| g.is_finite()), "grad_x must be finite");

        eprintln!(
            "LoRA backward: |grad_A|={norm_a:.4}, |grad_B|={norm_b:.4}, |grad_x|={norm_x:.4}"
        );
    }

    /// Load full Qwen3-4B model and verify memory fits in 16GB
    #[test]
    fn test_load_qwen3_4b_full_model() {
        let model_dir = match qwen3_4b_dir() {
            Some(dir) => dir,
            None => return,
        };

        let model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("load_qwen3_4b");

        assert_eq!(model.num_layers, 36);
        assert_eq!(model.hidden_size, 2560);
        assert_eq!(model.layers.len(), 36);
        assert_eq!(model.lora.len(), 36);

        let total_nf4_mb: f64 =
            model.layers.iter().map(|l| l.memory_bytes() as f64).sum::<f64>() / 1024.0 / 1024.0;
        let trainable = model.trainable_params();

        eprintln!("Qwen3-4B loaded: {total_nf4_mb:.0} MB NF4, {trainable} trainable params");

        // NF4 weights should be < 2GB total (36 layers * ~48MB each)
        assert!(total_nf4_mb < 2048.0, "NF4 total should be < 2GB, got {total_nf4_mb:.0} MB");

        // LoRA params per adapter = rank * (in_dim + out_dim). Assuming the `16` passed to
        // `load_qwen3_4b` is the LoRA rank, one 2560x2560 adapter per layer already gives
        // 36 * 16 * (2560 + 2560) ≈ 2.9M, so >1M is a conservative floor.
        assert!(trainable > 1_000_000, "Should have >1M trainable params, got {trainable}");
    }

    /// Run a single Qwen3-4B layer training step on AMD GPU
    ///
    /// This is the integration test: real NF4 weights → GPU dequant → FFN forward →
    /// FFN backward → AdamW on LoRA. Exercises the full per-layer pipeline.
    ///
    /// # Contract (C-WGPU-TRAIN-001)
    #[test]
    fn test_qwen3_4b_single_layer_train_step() {
        let model_dir = match qwen3_4b_dir() {
            Some(dir) => dir,
            None => return,
        };

        let config = qwen3_4b_config();

        let mut model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("load model");

        let mut trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");

        // Simulate hidden states (as if from embedding + prior layers)
        let seq_len = 4u32;
        let hidden: Vec<f32> = (0..(seq_len * 2560) as usize)
            .map(|i| ((i * 7 + 3) % 1000) as f32 / 1000.0 - 0.5)
            .collect();

        let start = std::time::Instant::now();
        let lora_set = &mut model.lora[0];
        let (lora_q, lora_v) = (&mut lora_set.q, &mut lora_set.v);
        let (output, grad_norm) = trainer
            .layer_train_step(&hidden, &mut model.layers[0], lora_q, lora_v, seq_len, 2560, 9728)
            .expect("layer_train_step");
        let elapsed = start.elapsed();

        assert_eq!(output.len(), (seq_len * 2560) as usize);
        assert!(output.iter().all(|v| v.is_finite()), "All outputs must be finite");
        assert!(grad_norm > 0.0, "Gradient norm must be positive");
        assert!(grad_norm.is_finite(), "Gradient norm must be finite");

        eprintln!(
            "Qwen3-4B layer 0 train step: {:.1}s, output_norm={:.4}, grad_norm={:.4}",
            elapsed.as_secs_f64(),
            l2_norm(&output),
            grad_norm,
        );
    }

    /// Run 3 steps of full 36-layer Qwen3-4B training on AMD GPU
    ///
    /// # Contract (C-WGPU-TRAIN-001): loss must be finite and positive
    #[test]
    fn test_qwen3_4b_full_36_layer_training() {
        let model_dir = match qwen3_4b_dir() {
            Some(dir) => dir,
            None => return,
        };

        let config = qwen3_4b_config();

        let mut model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("load model");

        let mut trainer = WgpuTransformerTrainer::new(&config, 5e-4).expect("trainer");

        // Simulate embedding output (seq_len=2 to keep it fast)
        let seq_len = 2u32;
        let hidden: Vec<f32> = (0..(seq_len * 2560) as usize)
            .map(|j| ((j * 7 + 3) % 1000) as f32 / 1000.0 - 0.5)
            .collect();
        let targets: Vec<u32> = vec![42, 100]; // arbitrary target tokens

        // Run 3 training steps
        let mut losses = Vec::new();
        for step in 0..3 {
            let start = std::time::Instant::now();
            let (loss, gnorm) =
                trainer.full_train_step(&hidden, &targets, &mut model).expect("full_train_step");
            let elapsed = start.elapsed();

            eprintln!(
                "Step {}: loss={:.3}, gnorm={:.4}, time={:.1}s",
                step + 1,
                loss,
                gnorm,
                elapsed.as_secs_f64()
            );
            losses.push(loss);

            assert!(loss.is_finite(), "Loss must be finite at step {}", step + 1);
            assert!(loss > 0.0, "Loss must be positive at step {}", step + 1);
            assert!(gnorm.is_finite(), "Grad norm must be finite at step {}", step + 1);
        }

        eprintln!(
            "Qwen3-4B 36-layer training: loss {:.3} -> {:.3} ({} steps)",
            losses[0],
            losses.last().expect("losses"),
            losses.len()
        );
    }
}
1487            losses.last().unwrap(),
1488            losses.len()
1489        );
1490    }
1491}