entrenar/finetune/instruct_pipeline/
training.rs1#[allow(clippy::wildcard_imports)]
2use super::*;
3
4#[cfg(feature = "cuda")]
5use crate::autograd::cuda_forward::gemm_forward;
6#[cfg(feature = "cuda")]
7use crate::autograd::cuda_optim::fused_causal_cross_entropy_cuda;
8
9impl InstructPipeline {
10 pub fn train_step(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> InstructStepResult {
18 self.profiler.begin_step();
19 let full_ids: Vec<u32> = prompt_ids.iter().chain(response_ids.iter()).copied().collect();
20
21 let prompt_len = prompt_ids.len();
22 let response_len = response_ids.len();
23
24 if response_len == 0 || full_ids.len() < 2 {
25 self.profiler.finish_step();
26 return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
27 }
28
29 let full_ids = if full_ids.len() > self.config.max_seq_len {
30 full_ids[..self.config.max_seq_len].to_vec()
31 } else {
32 full_ids
33 };
34 let seq_len = full_ids.len();
35 let vocab_size = self.model.config().vocab_size;
36
37 let prompt_len = prompt_len.min(seq_len);
40
41 #[cfg(feature = "cuda")]
46 if self.cuda_blocks.is_some() {
47 let result = self.cuda_train_step(&full_ids, prompt_len, seq_len, vocab_size);
48 self.profiler.finish_step();
49 return result;
50 }
51
52 #[cfg(feature = "gpu")]
54 if self.wgpu_training.is_some() {
55 return self.wgpu_train_step(&full_ids, prompt_len, seq_len, vocab_size);
56 }
57
58 for lora in &mut self.lora_layers {
62 for param in lora.trainable_params() {
63 param.zero_grad();
64 }
65 }
66
67 let logits = self.model.forward(&full_ids);
69 let logits_data = logits.data().as_slice().expect("contiguous logits").to_vec();
70
71 let loss_start = prompt_len.saturating_sub(1);
73 let loss_end = seq_len - 1;
74 let num_loss_tokens = loss_end.saturating_sub(loss_start);
75
76 if num_loss_tokens == 0 {
77 return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
78 }
79
80 let (avg_loss, grad_logits) =
81 Self::compute_causal_lm_loss(&logits_data, &full_ids, loss_start, loss_end, vocab_size);
82
83 logits.set_grad(ndarray::Array1::from(grad_logits));
85 if let Some(op) = logits.backward_op() {
86 op.backward();
87 }
88
89 let mut params: Vec<&mut Tensor> = Vec::new();
91 for lora in &mut self.lora_layers {
92 params.extend(lora.trainable_params());
93 }
94
95 if let Some(max_norm) = self.config.gradient_clip_norm {
96 clip_grad_norm_refs(&mut params, max_norm);
97 }
98
99 self.optimizer.step_refs(&mut params);
100
101 InstructStepResult {
102 loss: avg_loss,
103 num_response_tokens: num_loss_tokens,
104 perplexity: avg_loss.exp().min(1e6),
105 }
106 }
107 #[cfg(feature = "cuda")]
115 fn cuda_train_step(
116 &mut self,
117 full_ids: &[u32],
118 prompt_len: usize,
119 seq_len: usize,
120 vocab_size: usize,
121 ) -> InstructStepResult {
122 let max_pos = self.model.config().max_position_embeddings.min(512);
124 let seq_len = seq_len.min(max_pos);
125 let prompt_len = prompt_len.min(seq_len);
126 let loss_start = prompt_len.saturating_sub(1);
127 let loss_end = seq_len - 1;
128 let num_loss_tokens = loss_end.saturating_sub(loss_start);
129
130 if num_loss_tokens == 0 {
131 return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
132 }
133
134 let has_gpu_embed = self.gpu_training.as_ref().is_some_and(|t| {
137 t.embed_original.len() >= self.model.config().hidden_size * vocab_size
138 });
139
140 if !has_gpu_embed {
141 return self.cuda_train_step_cpu_loss(
142 full_ids,
143 loss_start,
144 loss_end,
145 num_loss_tokens,
146 seq_len,
147 vocab_size,
148 );
149 }
150
151 if self.profiler.is_enabled() {
153 if let Some(ref mut scratch) = self.shared_scratch {
154 scratch.op_profiling_enabled = true;
155 scratch.op_us = [0u64; 16];
156 }
157 }
158
159 self.profiler.begin(StepProfiler::FORWARD);
161 if !self.forward_logits_gpu_resident(full_ids) {
162 self.profiler.end(StepProfiler::FORWARD);
163 eprintln!("[CUDA] GPU forward failed, falling back to CPU for this step");
164 return self.cuda_train_step_cpu_loss(
165 full_ids,
166 loss_start,
167 loss_end,
168 num_loss_tokens,
169 seq_len,
170 vocab_size,
171 );
172 }
173 self.profiler.end(StepProfiler::FORWARD);
174
175 let targets: Vec<u32> = (0..seq_len)
177 .map(|pos| if pos + 1 < full_ids.len() { full_ids[pos + 1] } else { 0 })
178 .collect();
179
180 let scale = 1.0 / num_loss_tokens as f32;
181
182 self.profiler.begin(StepProfiler::LOSS);
183 let avg_loss = (|| -> Option<f32> {
184 let trainer = self.cuda_trainer.as_ref()?;
185 let stream = trainer.stream();
186 let training = self.gpu_training.as_mut()?;
187 fused_causal_cross_entropy_cuda(
188 &mut training.logits_buf,
189 &targets,
190 seq_len as u32,
191 vocab_size as u32,
192 loss_start as u32,
193 loss_end as u32,
194 scale,
195 stream,
196 )
197 .ok()
198 })();
199 self.profiler.end(StepProfiler::LOSS);
200
201 let avg_loss = match avg_loss {
202 Some(l) if l.is_finite() => {
203 eprintln!("[CUDA] loss={l:.4} (finite, proceeding with backward)");
204 l
205 }
206 Some(l) => {
207 eprintln!("[CUDA] NaN/Inf loss detected (loss={l}) — skipping backward pass");
208 return InstructStepResult {
209 loss: 100.0,
210 num_response_tokens: num_loss_tokens,
211 perplexity: 1e6,
212 };
213 }
214 None => {
215 eprintln!("[CUDA] fused causal cross-entropy failed — falling back to CPU");
216 return self.cuda_train_step_cpu_loss(
217 full_ids,
218 loss_start,
219 loss_end,
220 num_loss_tokens,
221 seq_len,
222 vocab_size,
223 );
224 }
225 };
226
227 self.profiler.begin(StepProfiler::LM_BWD);
229 let hidden_size = self.model.config().hidden_size;
230
231 let gemm_ok = (|| -> Option<()> {
232 let trainer = self.cuda_trainer.as_ref()?;
233 let stream = trainer.stream();
234 let training = self.gpu_training.as_mut()?;
235 if training.embed_original.len() < vocab_size * hidden_size {
236 return None;
237 }
238 gemm_forward(
239 &training.logits_buf,
240 &training.embed_original,
241 &mut training.grad_hidden_buf,
242 seq_len as u32,
243 vocab_size as u32,
244 hidden_size as u32,
245 stream,
246 )
247 .map_err(|e| eprintln!("[CUDA] lm_head backward GEMM failed: {e}"))
248 .ok()?;
249 Some(())
250 })();
251
252 self.profiler.end(StepProfiler::LM_BWD);
253
254 if gemm_ok.is_none() {
255 let cpu_ok = (|| -> Option<()> {
257 let trainer = self.cuda_trainer.as_ref()?;
258 let training = self.gpu_training.as_mut()?;
259 let embed = self.model.embed_tokens.weight.data();
260 let embed = embed.as_slice().expect("contiguous embed");
261 super::super::gpu_backward_fallback::cpu_lmhead_backward(
262 trainer,
263 &training.logits_buf,
264 &mut training.grad_hidden_buf,
265 embed,
266 seq_len,
267 vocab_size,
268 hidden_size,
269 trainer.stream(),
270 )
271 })();
272 if cpu_ok.is_none() {
273 return InstructStepResult {
274 loss: avg_loss,
275 num_response_tokens: num_loss_tokens,
276 perplexity: avg_loss.exp().min(1e6),
277 };
278 }
279 }
280
281 self.profiler.begin(StepProfiler::BLK_BWD);
283 if self.config.quantize_nf4 {
284 self.backward_nf4_gpu_blocks_gpu_resident(seq_len);
285 }
286 self.profiler.end(StepProfiler::BLK_BWD);
287
288 if let Some(ref training) = self.gpu_training {
290 self.profiler.record_layer_times(
291 &training.profiler_layer_fwd_us,
292 &training.profiler_layer_bwd_us,
293 );
294 }
295
296 if let Some(ref scratch) = self.shared_scratch {
298 if scratch.op_profiling_enabled {
299 for (i, &us) in scratch.op_us.iter().enumerate() {
300 if us > 0 {
301 self.profiler.end_op_raw(i, us);
302 }
303 }
304 }
305 }
306
307 InstructStepResult {
308 loss: avg_loss,
309 num_response_tokens: num_loss_tokens,
310 perplexity: avg_loss.exp().min(1e6),
311 }
312 }
313 #[cfg(feature = "cuda")]
316 fn cuda_train_step_cpu_loss(
317 &mut self,
318 full_ids: &[u32],
319 loss_start: usize,
320 loss_end: usize,
321 num_loss_tokens: usize,
322 seq_len: usize,
323 vocab_size: usize,
324 ) -> InstructStepResult {
325 let has_gpu_embed = self.gpu_training.as_ref().is_some_and(|t| {
328 t.embed_original.len() >= vocab_size * self.model.config().hidden_size
329 });
330
331 let logits_data = if has_gpu_embed {
332 match self.forward_logits_gpu(full_ids) {
333 Some(data) => data,
334 None => {
335 let logits = self.model.forward(full_ids);
336 logits.data().as_slice().expect("contiguous logits").to_vec()
337 }
338 }
339 } else {
340 match self.forward_inference_saving_inputs(full_ids) {
342 Some(data) => data,
343 None => {
344 let logits = self.model.forward(full_ids);
345 logits.data().as_slice().expect("contiguous logits").to_vec()
346 }
347 }
348 };
349
350 let (avg_loss, grad_logits) =
351 Self::compute_causal_lm_loss(&logits_data, full_ids, loss_start, loss_end, vocab_size);
352
353 if !avg_loss.is_finite() {
354 return InstructStepResult {
355 loss: 100.0,
356 num_response_tokens: num_loss_tokens,
357 perplexity: 1e6,
358 };
359 }
360
361 let hidden_size = self.model.config().hidden_size;
362
363 let grad_hidden = (|| -> Option<Vec<f32>> {
364 let trainer = self.cuda_trainer.as_ref()?;
365 let stream = trainer.stream();
366 let training = self.gpu_training.as_mut()?;
367 if training.logits_buf.len() < grad_logits.len() {
368 return None;
369 }
370 training
371 .logits_buf
372 .copy_from_host_at(&grad_logits, 0)
373 .map_err(|e| eprintln!("[CUDA] lm_head backward: grad_logits upload failed: {e}"))
374 .ok()?;
375 if training.embed_original.len() < vocab_size * hidden_size {
376 return None;
377 }
378 gemm_forward(
379 &training.logits_buf,
380 &training.embed_original,
381 &mut training.grad_hidden_buf,
382 seq_len as u32,
383 vocab_size as u32,
384 hidden_size as u32,
385 stream,
386 )
387 .map_err(|e| eprintln!("[CUDA] lm_head backward GEMM failed: {e}"))
388 .ok()?;
389 stream.synchronize().ok()?;
390 let full_grad = trainer.download(&training.grad_hidden_buf).ok()?;
391 Some(full_grad[..seq_len * hidden_size].to_vec())
392 })();
393
394 let grad_hidden = match grad_hidden {
395 Some(g) => g,
396 None => {
397 let hidden_size = self.model.config().hidden_size;
398 let lm_weight =
399 self.model.lm_head.as_ref().unwrap_or(&self.model.embed_tokens.weight);
400 let lm_data = lm_weight.data();
401 let lm_slice = lm_data.as_slice().expect("contiguous lm_head");
402 crate::autograd::ops::matmul::matmul_compute(
403 &grad_logits[..seq_len * vocab_size],
404 lm_slice,
405 seq_len,
406 vocab_size,
407 hidden_size,
408 )
409 }
410 };
411
412 if self.config.quantize_nf4 {
413 let grad_nz = grad_hidden.iter().filter(|&&x| x != 0.0).count();
414 static BWD_LOG: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
415 if BWD_LOG.fetch_add(1, std::sync::atomic::Ordering::Relaxed) < 3 {
416 eprintln!(
417 "[PMAT-420] backward: grad_hidden len={} nonzero={grad_nz} first5={:?}",
418 grad_hidden.len(),
419 &grad_hidden[..5.min(grad_hidden.len())]
420 );
421 }
422 self.backward_nf4_gpu_blocks(&grad_hidden, seq_len);
423 }
424
425 InstructStepResult {
426 loss: avg_loss,
427 num_response_tokens: num_loss_tokens,
428 perplexity: avg_loss.exp().min(1e6),
429 }
430 }
431 pub fn evaluate(
433 &self,
434 prompt_ids_batch: &[Vec<u32>],
435 response_ids_batch: &[Vec<u32>],
436 ) -> InstructBatchResult {
437 let mut total_loss = 0.0f32;
438 let mut total_response_tokens = 0usize;
439
440 for (prompt_ids, response_ids) in prompt_ids_batch.iter().zip(response_ids_batch.iter()) {
441 let full_ids: Vec<u32> =
442 prompt_ids.iter().chain(response_ids.iter()).copied().collect();
443
444 let prompt_len = prompt_ids.len();
445 if response_ids.is_empty() || full_ids.len() < 2 {
446 continue;
447 }
448
449 let full_ids = if full_ids.len() > self.config.max_seq_len {
450 full_ids[..self.config.max_seq_len].to_vec()
451 } else {
452 full_ids
453 };
454 let seq_len = full_ids.len();
455 let vocab_size = self.model.config().vocab_size;
456 let prompt_len = prompt_len.min(seq_len);
457
458 let logits = self.model.forward(&full_ids);
459 let logits_data = logits.data().as_slice().expect("contiguous logits").to_vec();
460
461 let loss_start = prompt_len.saturating_sub(1);
462 let loss_end = seq_len - 1;
463 let num_loss_tokens = loss_end.saturating_sub(loss_start);
464
465 let (sample_loss, _) = Self::compute_causal_lm_loss(
466 &logits_data,
467 &full_ids,
468 loss_start,
469 loss_end,
470 vocab_size,
471 );
472
473 total_loss += sample_loss * num_loss_tokens as f32;
474 total_response_tokens += num_loss_tokens;
475 }
476
477 let avg_loss =
478 if total_response_tokens > 0 { total_loss / total_response_tokens as f32 } else { 0.0 };
479
480 InstructBatchResult {
481 avg_loss,
482 total_response_tokens,
483 perplexity: avg_loss.exp().min(1e6),
484 grad_norm: 0.0,
485 }
486 }
487 pub(super) fn compute_causal_lm_loss(
491 logits_data: &[f32],
492 full_ids: &[u32],
493 loss_start: usize,
494 loss_end: usize,
495 vocab_size: usize,
496 ) -> (f32, Vec<f32>) {
497 let seq_len = full_ids.len();
498 let num_loss_tokens = loss_end.saturating_sub(loss_start);
499 let mut total_loss = 0.0f32;
500 let mut grad_logits = vec![0.0f32; seq_len * vocab_size];
501
502 for pos in loss_start..loss_end {
503 let target = full_ids[pos + 1] as usize;
504 if target >= vocab_size {
505 continue;
506 }
507
508 let logit_start = pos * vocab_size;
509 let row = &logits_data[logit_start..logit_start + vocab_size];
510
511 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
512 let grad_row = &mut grad_logits[logit_start..logit_start + vocab_size];
513 let mut sum_exp = 0.0f32;
514 for j in 0..vocab_size {
515 let exp_v = (row[j] - max_val).exp();
516 grad_row[j] = exp_v;
517 sum_exp += exp_v;
518 }
519
520 let log_sum_exp = sum_exp.ln() + max_val;
521 let loss_i = -(row[target] - log_sum_exp);
522 total_loss += if loss_i.is_finite() { loss_i } else { 100.0 };
523
524 let inv_n = 1.0 / num_loss_tokens as f32;
525 let scale = inv_n / sum_exp;
526 for j in 0..vocab_size {
527 grad_row[j] *= scale;
528 }
529 grad_row[target] -= inv_n;
530 }
531
532 let avg_loss = if num_loss_tokens > 0 { total_loss / num_loss_tokens as f32 } else { 0.0 };
533
534 (avg_loss, grad_logits)
535 }
536}