entrenar/finetune/instruct_pipeline/
wgpu.rs1#[cfg(feature = "gpu")]
4use super::{
5 clip_grad_norm_refs, InstructPipeline, InstructStepResult, Optimizer, Tensor,
6 TransformerConfig, WgpuTrainingState,
7};
8
#[cfg(feature = "gpu")]
impl InstructPipeline {
    /// Runs one instruction-tuning step with the cross-entropy computed on the GPU.
    ///
    /// Phases, in order:
    /// 1. CPU forward pass of `self.model` over `full_ids` to obtain the logits.
    /// 2. Upload of the first `seq_len * vocab_size` logits and of the shifted labels
    ///    (`labels[i] = full_ids[i + 1]`, `0` past the end of `full_ids`).
    /// 3. Fused cross-entropy forward over positions `loss_start..loss_end`, where
    ///    `loss_start = prompt_len - 1` and `loss_end = seq_len - 1` (both saturating),
    ///    so only response tokens contribute to the loss.
    /// 4. Cross-entropy backward; the result is read back from `logits_buf` and pushed
    ///    into the CPU autograd graph as the gradient of the logits.
    /// 5. Optional gradient-norm clipping and an optimizer step on the LoRA parameters.
    ///
    /// Per-phase timings are printed to stderr with a `[PROFILE]` prefix.
    ///
    /// # Returns
    /// * `loss = 0.0`, `perplexity = 1.0`, no update — when there are no response
    ///   tokens (this includes `seq_len == 0`).
    /// * `loss = 100.0`, `perplexity = 1e6`, no update — when the loss is NaN/Inf.
    /// * Otherwise the loss returned by the GPU kernel and `exp(loss)` capped at `1e6`.
    ///
    /// # Panics
    /// Panics if `try_init_wgpu` has not installed the wgpu state, if the logits are
    /// not contiguous, or if the forward pass produced fewer than
    /// `seq_len * vocab_size` logits.
    pub(super) fn wgpu_train_step(
        &mut self,
        full_ids: &[u32],
        prompt_len: usize,
        seq_len: usize,
        vocab_size: usize,
    ) -> InstructStepResult {
        let loss_start = prompt_len.saturating_sub(1);
        // Saturating: a plain `seq_len - 1` underflows for an empty sequence.
        let loss_end = seq_len.saturating_sub(1);
        let num_loss_tokens = loss_end.saturating_sub(loss_start);

        if num_loss_tokens == 0 {
            return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
        }

        // Fail fast (before the costly forward pass) if the GPU state is missing.
        let wgpu = self
            .wgpu_training
            .as_ref()
            .expect("wgpu_train_step called without wgpu training state (try_init_wgpu)");
        let num_logits = seq_len * vocab_size;

        let step_start = std::time::Instant::now();
        let logits_tensor = self.model.forward(full_ids);
        let fwd_ms = step_start.elapsed().as_millis();
        eprintln!("[PROFILE] cpu_forward: {fwd_ms}ms");

        let ce_start = std::time::Instant::now();
        {
            // Upload directly from the tensor's storage instead of copying it into a
            // temporary Vec first; the scope limits the borrow of the logits data.
            let logits_storage = logits_tensor.data();
            let logits = logits_storage.as_slice().expect("contiguous");
            wgpu.trainer.queue_ref().write_buffer(
                &wgpu.logits_buf,
                0,
                bytemuck::cast_slice(&logits[..num_logits]),
            );
        }

        // Next-token targets; positions with no successor in `full_ids` are padded with 0.
        let labels: Vec<u32> =
            (0..seq_len).map(|i| full_ids.get(i + 1).copied().unwrap_or(0)).collect();
        wgpu.trainer.queue_ref().write_buffer(&wgpu.labels_buf, 0, bytemuck::cast_slice(&labels));

        let avg_loss = wgpu.cross_entropy.forward(
            &wgpu.logits_buf,
            &wgpu.labels_buf,
            &wgpu.losses_buf,
            &wgpu.logsumexp_buf,
            seq_len as u32,
            vocab_size as u32,
            loss_start as u32,
            loss_end as u32,
        );

        if !avg_loss.is_finite() {
            eprintln!("[wgpu] NaN/Inf loss detected — skipping backward");
            return InstructStepResult {
                loss: 100.0,
                num_response_tokens: num_loss_tokens,
                perplexity: 1e6,
            };
        }

        // The gradient is read back from `logits_buf` below, i.e. the backward kernel
        // is expected to overwrite the logits in place with d(loss)/d(logits).
        wgpu.cross_entropy.backward(
            &wgpu.logits_buf,
            &wgpu.labels_buf,
            &wgpu.logsumexp_buf,
            seq_len as u32,
            vocab_size as u32,
            loss_start as u32,
            loss_end as u32,
        );
        let ce_ms = ce_start.elapsed().as_millis();
        eprintln!("[PROFILE] fused_ce: {ce_ms}ms");

        let bwd_start = std::time::Instant::now();
        let grad_logits = wgpu.trainer.download(&wgpu.logits_buf);
        logits_tensor.set_grad(ndarray::Array1::from(grad_logits[..num_logits].to_vec()));
        if let Some(op) = logits_tensor.backward_op() {
            op.backward();
        }

        // Only the LoRA adapters are trained.
        let mut params: Vec<&mut Tensor> = Vec::new();
        for lora in &mut self.lora_layers {
            params.extend(lora.trainable_params());
        }
        if let Some(max_norm) = self.config.gradient_clip_norm {
            clip_grad_norm_refs(&mut params, max_norm);
        }
        self.optimizer.step_refs(&mut params);

        let bwd_ms = bwd_start.elapsed().as_millis();
        eprintln!("[PROFILE] lm_head_backward: {bwd_ms}ms");
        eprintln!(
            "[PROFILE] total_step: {}ms (fwd={fwd_ms} ce={ce_ms} bwd={bwd_ms})",
            step_start.elapsed().as_millis(),
        );

        InstructStepResult {
            loss: avg_loss,
            num_response_tokens: num_loss_tokens,
            perplexity: avg_loss.exp().min(1e6),
        }
    }

    /// Tries to set up the GPU training state; on any failure it logs to stderr and
    /// returns without touching `self.wgpu_training`, so training stays on the CPU.
    ///
    /// On success it:
    /// * creates a `WgpuTrainer` and a `WgslForwardPass` sized from `model_config`,
    /// * uploads only the per-layer `attn_norm` / `ffn_norm` weights and initialises
    ///   the forward pass's KV cache for `num_hidden_layers` layers,
    /// * allocates `logits` / `labels` / `losses` / `logsumexp` storage buffers sized
    ///   for `self.config.max_seq_len` positions,
    /// * uploads the LM head both as stored and transposed,
    ///
    /// and stores everything in `self.wgpu_training`.
    ///
    /// Falls back to the CPU when the wgpu trainer cannot be created, when
    /// `num_attention_heads` is zero, or when the LM head weight does not hold exactly
    /// `vocab_size * hidden_size` values.
    pub(super) fn try_init_wgpu(&mut self, model_config: &TransformerConfig) {
        use crate::autograd::wgpu_cross_entropy::WgslCrossEntropy;
        use crate::autograd::wgpu_training::WgpuTrainer;

        let seq = self.config.max_seq_len as u32;
        let vocab = model_config.vocab_size as u32;
        let hidden = model_config.hidden_size as u32;
        let num_layers = model_config.num_hidden_layers;
        let num_heads = model_config.num_attention_heads as u32;
        let num_kv_heads = model_config.num_kv_heads as u32;
        let inter = model_config.intermediate_size as u32;

        // Guard the division below instead of panicking inside an init routine that
        // otherwise degrades gracefully.
        if num_heads == 0 {
            eprintln!("[wgpu] num_attention_heads is 0 — using CPU");
            return;
        }
        let head_dim = hidden / num_heads;

        let trainer = match WgpuTrainer::new() {
            Ok(t) => t,
            Err(e) => {
                eprintln!("[wgpu] Failed to init: {e} — using CPU");
                return;
            }
        };

        let mut fwd = trueno::backends::gpu::WgslForwardPass::new(
            trainer.device_ref().clone(),
            trainer.queue_ref().clone(),
            hidden as usize,
            num_heads as usize,
            num_kv_heads as usize,
            head_dim as usize,
            inter as usize,
        );

        // Only the norm weights are uploaded here; the log line below records that
        // projections are handled on demand.
        let mut uploaded = 0usize;
        for (name, tensor) in self.model.named_parameters() {
            let Some(data) = tensor.data().as_slice() else {
                continue;
            };
            let gpu_name = Self::gpu_weight_name(&name);
            if gpu_name.ends_with(".attn_norm") || gpu_name.ends_with(".ffn_norm") {
                fwd.upload_weight(&gpu_name, data);
                uploaded += 1;
            }
        }

        fwd.init_kv_cache(num_layers);

        eprintln!(
            "[wgpu] Uploaded {uploaded} norm weights ({num_layers} layers, projections on-demand)"
        );

        let h = hidden as usize;
        let v = vocab as usize;
        let lm_head_raw = self.model.lm_head_weight_slice();
        // The transpose below reads `vocab * hidden` values; a mismatch would either
        // panic out of bounds or leave the two GPU copies inconsistent.
        if lm_head_raw.len() != h * v {
            eprintln!(
                "[wgpu] lm_head has {} values, expected vocab*hidden = {} — using CPU",
                lm_head_raw.len(),
                h * v
            );
            return;
        }
        let lm_head_transposed = Self::transpose_row_major(lm_head_raw, v, h);
        let lm_head_gpu = trainer.upload(lm_head_raw);
        let lm_head_t_gpu = trainer.upload(&lm_head_transposed);
        // Release the host-side copy before allocating the (potentially large) buffers.
        drop(lm_head_transposed);

        let make_buf = |elems: u64, label: &str| -> trueno::backends::gpu::wgpu::Buffer {
            trainer.device_ref().create_buffer(&trueno::backends::gpu::wgpu::BufferDescriptor {
                label: Some(label),
                // `elems` is an element count; every element is 4 bytes (f32 / u32).
                size: elems * 4,
                usage: trueno::backends::gpu::wgpu::BufferUsages::STORAGE
                    | trueno::backends::gpu::wgpu::BufferUsages::COPY_SRC
                    | trueno::backends::gpu::wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };
        let logits_buf = make_buf(u64::from(seq) * u64::from(vocab), "logits");
        let labels_buf = make_buf(u64::from(seq), "labels");
        let losses_buf = make_buf(u64::from(seq), "losses");
        let logsumexp_buf = make_buf(u64::from(seq), "logsumexp");

        let cross_entropy =
            WgslCrossEntropy::new(trainer.device_ref().clone(), trainer.queue_ref().clone());

        eprintln!(
            "[wgpu] Training initialized (seq={seq}, vocab={vocab}, layers={num_layers}, lm_head on GPU)"
        );

        self.wgpu_training = Some(WgpuTrainingState {
            fwd,
            logits_buf,
            labels_buf,
            losses_buf,
            logsumexp_buf,
            cross_entropy,
            trainer,
            lm_head_gpu,
            lm_head_t_gpu,
            num_layers,
            hidden_dim: h,
            vocab_size: v,
        });
    }

    /// Maps a `model.layers.N.…` parameter name to the key used by the WGSL forward
    /// pass, e.g. `model.layers.3.input_layernorm.weight` -> `layer.3.attn_norm`.
    fn gpu_weight_name(name: &str) -> String {
        name.replace("model.layers.", "layer.")
            .replace(".input_layernorm.weight", ".attn_norm")
            .replace(".post_attention_layernorm.weight", ".ffn_norm")
            .replace(".self_attn.", ".")
            .replace(".mlp.", ".")
            .replace(".weight", "")
    }

    /// Transposes a row-major `rows x cols` matrix into a row-major `cols x rows` one.
    ///
    /// Reads `rows * cols` values from `src`; callers must ensure `src` is that long.
    fn transpose_row_major(src: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut dst = vec![0.0f32; rows * cols];
        // `chunks_exact(0)` panics; a zero-width matrix has nothing to move anyway.
        if cols == 0 {
            return dst;
        }
        for (r, row) in src.chunks_exact(cols).take(rows).enumerate() {
            for (c, &value) in row.iter().enumerate() {
                dst[c * rows + r] = value;
            }
        }
        dst
    }
}