Skip to main content

entrenar/finetune/instruct_pipeline/
wgpu.rs

1//! WGPU GPU acceleration: `try_init_wgpu`, `wgpu_train_step`.
2
3#[cfg(feature = "gpu")]
4use super::{
5    clip_grad_norm_refs, InstructPipeline, InstructStepResult, Optimizer, Tensor,
6    TransformerConfig, WgpuTrainingState,
7};
8
#[cfg(feature = "gpu")]
impl InstructPipeline {
    /// wgpu GPU training step (§26 WgpuTrainingPipeline).
    ///
    /// Current split of work:
    /// - forward pass: CPU (`self.model.forward`);
    /// - cross-entropy forward + backward: GPU (`WgslCrossEntropy`), with the logits
    ///   gradient written in place into `logits_buf`;
    /// - backward through the model: CPU autograd, seeded with the downloaded gradient;
    /// - optional gradient clipping + optimizer step: `self.optimizer` over the LoRA
    ///   parameters (`Tensor`s).
    ///
    /// Full GPU forward/backward is Step 0d.2/0d.3 (WgslForwardPass/WgslBackwardPass);
    /// this is the integration point that proves the pipeline end-to-end first.
    ///
    /// The loss covers positions `[prompt_len - 1, seq_len - 1)`, i.e. the predictions
    /// of the response tokens. Returns a zero-loss result when that window is empty,
    /// and a sentinel result (`loss = 100.0`, `perplexity = 1e6`) without touching
    /// any parameters when the GPU loss is NaN/Inf. Per-phase timings are written to
    /// stderr as `[PROFILE]` lines on every successful step.
    ///
    /// # Panics
    /// - if `try_init_wgpu` has not populated `self.wgpu_training`;
    /// - if the model's logits are not contiguous or hold fewer than
    ///   `seq_len * vocab_size` values.
    pub(super) fn wgpu_train_step(
        &mut self,
        full_ids: &[u32],
        prompt_len: usize,
        seq_len: usize,
        vocab_size: usize,
    ) -> InstructStepResult {
        // Position i predicts token i + 1, so the loss window is [prompt_len - 1, seq_len - 1).
        // `saturating_sub` keeps an empty prompt or an empty sequence from underflowing.
        let loss_start = prompt_len.saturating_sub(1);
        let loss_end = seq_len.saturating_sub(1);
        let num_loss_tokens = loss_end.saturating_sub(loss_start);

        if num_loss_tokens == 0 {
            return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
        }

        let gpu = self
            .wgpu_training
            .as_ref()
            .expect("wgpu_train_step requires a successful try_init_wgpu");
        let queue = gpu.trainer.queue_ref();
        let n_logits = seq_len * vocab_size;

        // KAIZEN: instrument every phase to find the bottleneck.
        let t_start = std::time::Instant::now();

        // 1. Forward pass on the CPU.
        let logits_tensor = self.model.forward(full_ids);
        let t_forward = std::time::Instant::now();
        eprintln!("[PROFILE] cpu_forward: {}ms", t_forward.duration_since(t_start).as_millis());

        // 2. Upload logits + shifted labels and run the fused cross-entropy on the GPU.
        {
            // Borrow the logits in place instead of cloning them into a Vec first:
            // `write_buffer` copies the bytes before it returns.
            let logits_data = logits_tensor.data();
            let logits = logits_data.as_slice().expect("contiguous");
            queue.write_buffer(&gpu.logits_buf, 0, bytemuck::cast_slice(&logits[..n_logits]));
        }

        // Shifted labels: position i is trained to predict token i + 1; positions
        // past the end of `full_ids` get a 0 placeholder.
        let labels: Vec<u32> =
            (0..seq_len).map(|i| full_ids.get(i + 1).copied().unwrap_or(0)).collect();
        queue.write_buffer(&gpu.labels_buf, 0, bytemuck::cast_slice(&labels));

        let avg_loss = gpu.cross_entropy.forward(
            &gpu.logits_buf,
            &gpu.labels_buf,
            &gpu.losses_buf,
            &gpu.logsumexp_buf,
            seq_len as u32,
            vocab_size as u32,
            loss_start as u32,
            loss_end as u32,
        );

        if !avg_loss.is_finite() {
            eprintln!("[wgpu] NaN/Inf loss detected — skipping backward");
            return InstructStepResult {
                loss: 100.0,
                num_response_tokens: num_loss_tokens,
                perplexity: 1e6,
            };
        }

        // Fused CE backward overwrites `logits_buf` in place with dL/dlogits.
        gpu.cross_entropy.backward(
            &gpu.logits_buf,
            &gpu.labels_buf,
            &gpu.logsumexp_buf,
            seq_len as u32,
            vocab_size as u32,
            loss_start as u32,
            loss_end as u32,
        );

        let t_ce = std::time::Instant::now();
        eprintln!("[PROFILE] fused_ce: {}ms", t_ce.duration_since(t_forward).as_millis());

        // 3. Backward on CPU autograd, seeded with the GPU-computed logits gradient.
        // `logits_buf` is allocated for `max_seq_len` rows, so drop the unused tail
        // instead of copying the used prefix into a second Vec.
        let mut grad_logits = gpu.trainer.download(&gpu.logits_buf);
        grad_logits.truncate(n_logits);
        logits_tensor.set_grad(ndarray::Array1::from(grad_logits));
        if let Some(op) = logits_tensor.backward_op() {
            op.backward();
        }
        let t_backward = std::time::Instant::now();

        // 4. Optional gradient clipping, then the optimizer step on the LoRA parameters.
        let mut params: Vec<&mut Tensor> = Vec::new();
        for lora in &mut self.lora_layers {
            params.extend(lora.trainable_params());
        }
        if let Some(max_norm) = self.config.gradient_clip_norm {
            clip_grad_norm_refs(&mut params, max_norm);
        }
        self.optimizer.step_refs(&mut params);
        let t_optim = std::time::Instant::now();

        eprintln!("[PROFILE] backward_optim: {}ms", t_optim.duration_since(t_ce).as_millis());
        eprintln!(
            "[PROFILE] total_step: {}ms (fwd={} ce={} bwd={} opt={})",
            t_optim.duration_since(t_start).as_millis(),
            t_forward.duration_since(t_start).as_millis(),
            t_ce.duration_since(t_forward).as_millis(),
            t_backward.duration_since(t_ce).as_millis(),
            t_optim.duration_since(t_backward).as_millis(),
        );

        InstructStepResult {
            loss: avg_loss,
            num_response_tokens: num_loss_tokens,
            perplexity: avg_loss.exp().min(1e6),
        }
    }

    // ── wgpu GPU acceleration (§26 WgpuTrainingPipeline) ────────────────

    /// Try to set up the wgpu training state, falling back to CPU on failure.
    ///
    /// On success `self.wgpu_training` is populated with:
    /// - a `WgslForwardPass` holding this model's layer-norm weights (projection
    ///   weights are not uploaded here) and a KV cache for every layer;
    /// - the fused cross-entropy kernel plus its logits/labels/losses/logsumexp
    ///   buffers, sized for `self.config.max_seq_len` positions;
    /// - the `lm_head` weight in both its original and transposed layout.
    ///
    /// If `WgpuTrainer::new` fails, a required buffer exceeds the device's
    /// `max_buffer_size`, or the `lm_head` weight holds fewer than
    /// `vocab_size × hidden_size` values, the reason is logged to stderr and
    /// `self.wgpu_training` is left untouched so training continues on the CPU.
    pub(super) fn try_init_wgpu(&mut self, model_config: &TransformerConfig) {
        use crate::autograd::wgpu_cross_entropy::WgslCrossEntropy;
        use crate::autograd::wgpu_training::WgpuTrainer;
        use trueno::backends::gpu::wgpu;

        // Every buffer allocated here holds 4-byte elements
        // (f32 logits/losses/logsumexp, u32 labels).
        const WORD_BYTES: u64 = 4;

        let trainer = match WgpuTrainer::new() {
            Ok(t) => t,
            Err(e) => {
                eprintln!("[wgpu] Failed to init: {e} — using CPU");
                return;
            }
        };

        let seq = self.config.max_seq_len as usize;
        let vocab = model_config.vocab_size as usize;
        let hidden = model_config.hidden_size as usize;
        let num_layers = model_config.num_hidden_layers;
        let num_heads = model_config.num_attention_heads as usize;
        let num_kv_heads = model_config.num_kv_heads as usize;
        // NOTE(review): assumes head_dim == hidden_size / num_attention_heads; models
        // that configure head_dim independently would need it read from the config.
        let head_dim = hidden / num_heads;
        let inter = model_config.intermediate_size as usize;

        // Refuse up front rather than letting wgpu validation panic on an
        // oversized allocation (the logits buffer grows with seq × vocab).
        let logits_bytes = (seq as u64) * (vocab as u64) * WORD_BYTES;
        let lm_head_bytes = (hidden as u64) * (vocab as u64) * WORD_BYTES;
        let max_buffer = trainer.device_ref().limits().max_buffer_size;
        if logits_bytes.max(lm_head_bytes) > max_buffer {
            eprintln!(
                "[wgpu] Buffers exceed device max_buffer_size (logits={logits_bytes} B, lm_head={lm_head_bytes} B, limit={max_buffer} B) — using CPU"
            );
            return;
        }

        // Create WgslForwardPass with persistent weight buffers + tiled GEMM.
        let mut fwd = trueno::backends::gpu::WgslForwardPass::new(
            trainer.device_ref().clone(),
            trainer.queue_ref().clone(),
            hidden,
            num_heads,
            num_kv_heads,
            head_dim,
            inter,
        );

        // KAIZEN: only the (small) per-layer norm weights are uploaded up front;
        // projection weights are handled on demand.
        let mut uploaded = 0usize;
        for (name, tensor) in self.model.named_parameters() {
            // Map HF-style names onto the forward pass's naming, e.g.
            // `model.layers.3.input_layernorm.weight` -> `layer.3.attn_norm`.
            let gpu_name = name
                .replace("model.layers.", "layer.")
                .replace(".input_layernorm.weight", ".attn_norm")
                .replace(".post_attention_layernorm.weight", ".ffn_norm")
                .replace(".self_attn.", ".")
                .replace(".mlp.", ".")
                .replace(".weight", "");

            if !(gpu_name.ends_with(".attn_norm") || gpu_name.ends_with(".ffn_norm")) {
                continue;
            }
            // Only contiguous tensors can be uploaded as a flat slice.
            let data = match tensor.data().as_slice() {
                Some(s) => s,
                None => continue,
            };
            fwd.upload_weight(&gpu_name, data);
            uploaded += 1;
        }

        fwd.init_kv_cache(num_layers);

        eprintln!(
            "[wgpu] Uploaded {uploaded} norm weights ({num_layers} layers, projections on-demand)"
        );

        let cross_entropy =
            WgslCrossEntropy::new(trainer.device_ref().clone(), trainer.queue_ref().clone());

        // KAIZEN: transpose lm_head once on the host and keep both layouts on the GPU.
        let lm_head_raw = self.model.lm_head_weight_slice();
        if lm_head_raw.len() < vocab * hidden {
            eprintln!(
                "[wgpu] lm_head has {} values, expected at least {} (vocab={vocab} x hidden={hidden}) — using CPU",
                lm_head_raw.len(),
                vocab * hidden
            );
            return;
        }
        let lm_head_transposed = Self::wgpu_transpose_row_major(lm_head_raw, vocab, hidden);
        let lm_head_gpu = trainer.upload(lm_head_raw);
        let lm_head_t_gpu = trainer.upload(&lm_head_transposed);
        // Release the host copy (vocab × hidden floats) as soon as it is on the GPU.
        drop(lm_head_transposed);
        eprintln!(
            "[wgpu] Training initialized (seq={seq}, vocab={vocab}, layers={num_layers}, lm_head on GPU)"
        );

        // Storage buffers that the host writes into and reads back from.
        let make_buf = |elems: u64, label: &str| -> wgpu::Buffer {
            trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: elems * WORD_BYTES,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };
        let logits_buf = make_buf((seq as u64) * (vocab as u64), "logits");
        let labels_buf = make_buf(seq as u64, "labels");
        let losses_buf = make_buf(seq as u64, "losses");
        let logsumexp_buf = make_buf(seq as u64, "logsumexp");

        self.wgpu_training = Some(WgpuTrainingState {
            fwd,
            logits_buf,
            labels_buf,
            losses_buf,
            logsumexp_buf,
            cross_entropy,
            trainer,
            lm_head_gpu,
            lm_head_t_gpu,
            num_layers,
            hidden_dim: hidden,
            vocab_size: vocab,
        });
    }

    /// Transpose the first `rows * cols` values of the row-major `rows × cols`
    /// matrix `src` into a new row-major `cols × rows` matrix.
    ///
    /// # Panics
    /// If `src` holds fewer than `rows * cols` values.
    fn wgpu_transpose_row_major(src: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut dst = vec![0.0f32; rows * cols];
        // `chunks_exact(0)` panics; a zero-width matrix has nothing to move.
        if cols == 0 {
            return dst;
        }
        for (r, row) in src[..rows * cols].chunks_exact(cols).enumerate() {
            for (c, &x) in row.iter().enumerate() {
                dst[c * rows + r] = x;
            }
        }
        dst
    }
}