Skip to main content

entrenar/finetune/instruct_pipeline/
training.rs

1#[allow(clippy::wildcard_imports)]
2use super::*;
3
4#[cfg(feature = "cuda")]
5use crate::autograd::cuda_forward::gemm_forward;
6#[cfg(feature = "cuda")]
7use crate::autograd::cuda_optim::fused_causal_cross_entropy_cuda;
8
9impl InstructPipeline {
10    /// Compute causal LM loss on a single instruction-response pair.
11    ///
12    /// # Contract (F-INST-002)
13    /// Loss is computed only on response tokens. Prompt tokens are masked.
14    ///
15    /// When CUDA NF4 blocks are available, dispatches to GPU forward pass
16    /// with CPU loss computation and GPU backward/optimizer.
17    pub fn train_step(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> InstructStepResult {
18        self.profiler.begin_step();
19        let full_ids: Vec<u32> = prompt_ids.iter().chain(response_ids.iter()).copied().collect();
20
21        let prompt_len = prompt_ids.len();
22        let response_len = response_ids.len();
23
24        if response_len == 0 || full_ids.len() < 2 {
25            self.profiler.finish_step();
26            return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
27        }
28
29        let full_ids = if full_ids.len() > self.config.max_seq_len {
30            full_ids[..self.config.max_seq_len].to_vec()
31        } else {
32            full_ids
33        };
34        let seq_len = full_ids.len();
35        let vocab_size = self.model.config().vocab_size;
36
37        // Cap prompt_len at truncated sequence length. If the prompt alone
38        // exceeds max_seq_len, all response tokens were truncated away.
39        let prompt_len = prompt_len.min(seq_len);
40
41        // ── CUDA GPU path (NF4 QLoRA) ─────────────────────────────────
42        // ── CUDA GPU path (NF4 QLoRA) ─────────────────────────────────
43        // PMAT-420: Use CUDA path for ALL configs. On 8GB, the inference-style
44        // forward (fresh buffers, saves inputs) replaces the NaN-prone training forward.
45        #[cfg(feature = "cuda")]
46        if self.cuda_blocks.is_some() {
47            let result = self.cuda_train_step(&full_ids, prompt_len, seq_len, vocab_size);
48            self.profiler.finish_step();
49            return result;
50        }
51
52        // ── wgpu GPU path (§26 WgpuTrainingPipeline) ─────────────────
53        #[cfg(feature = "gpu")]
54        if self.wgpu_training.is_some() {
55            return self.wgpu_train_step(&full_ids, prompt_len, seq_len, vocab_size);
56        }
57
58        // ── CPU path ──────────────────────────────────────────────────
59
60        // 1. Zero gradients
61        for lora in &mut self.lora_layers {
62            for param in lora.trainable_params() {
63                param.zero_grad();
64            }
65        }
66
67        // 2. Forward pass → logits [seq_len, vocab_size]
68        let logits = self.model.forward(&full_ids);
69        let logits_data = logits.data().as_slice().expect("contiguous logits").to_vec();
70
71        // 3. Causal LM loss on response tokens only
72        let loss_start = prompt_len.saturating_sub(1);
73        let loss_end = seq_len - 1;
74        let num_loss_tokens = loss_end.saturating_sub(loss_start);
75
76        if num_loss_tokens == 0 {
77            return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
78        }
79
80        let (avg_loss, grad_logits) =
81            Self::compute_causal_lm_loss(&logits_data, &full_ids, loss_start, loss_end, vocab_size);
82
83        // 4. Backward through autograd
84        logits.set_grad(ndarray::Array1::from(grad_logits));
85        if let Some(op) = logits.backward_op() {
86            op.backward();
87        }
88
89        // 5. Optimizer step on LoRA parameters
90        let mut params: Vec<&mut Tensor> = Vec::new();
91        for lora in &mut self.lora_layers {
92            params.extend(lora.trainable_params());
93        }
94
95        if let Some(max_norm) = self.config.gradient_clip_norm {
96            clip_grad_norm_refs(&mut params, max_norm);
97        }
98
99        self.optimizer.step_refs(&mut params);
100
101        InstructStepResult {
102            loss: avg_loss,
103            num_response_tokens: num_loss_tokens,
104            perplexity: avg_loss.exp().min(1e6),
105        }
106    }
107    /// GPU-accelerated training step for NF4 QLoRA.
108    ///
109    /// 1. GPU forward through NF4 transformer blocks → normed hidden states
110    /// 2. CPU lm_head matmul → logits
111    /// 3. CPU causal LM loss on response tokens only
112    /// 4. CPU gradient of loss w.r.t. hidden states (through lm_head)
113    /// 5. GPU backward through NF4 blocks → LoRA gradient + optimizer step
114    #[cfg(feature = "cuda")]
115    fn cuda_train_step(
116        &mut self,
117        full_ids: &[u32],
118        prompt_len: usize,
119        seq_len: usize,
120        vocab_size: usize,
121    ) -> InstructStepResult {
122        // entrenar#318: truncate seq_len to match forward_cuda_training's max
123        let max_pos = self.model.config().max_position_embeddings.min(512);
124        let seq_len = seq_len.min(max_pos);
125        let prompt_len = prompt_len.min(seq_len);
126        let loss_start = prompt_len.saturating_sub(1);
127        let loss_end = seq_len - 1;
128        let num_loss_tokens = loss_end.saturating_sub(loss_start);
129
130        if num_loss_tokens == 0 {
131            return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
132        }
133
134        // PMAT-420: If GPU embeddings are minimal (VRAM-constrained), skip the GPU-resident
135        // logits path entirely — go straight to CPU-loss path which uses GPU transformer + CPU lm_head.
136        let has_gpu_embed = self.gpu_training.as_ref().is_some_and(|t| {
137            t.embed_original.len() >= self.model.config().hidden_size * vocab_size
138        });
139
140        if !has_gpu_embed {
141            return self.cuda_train_step_cpu_loss(
142                full_ids,
143                loss_start,
144                loss_end,
145                num_loss_tokens,
146                seq_len,
147                vocab_size,
148            );
149        }
150
151        // PMAT-483: Enable per-op profiling on scratch if profiler is active
152        if self.profiler.is_enabled() {
153            if let Some(ref mut scratch) = self.shared_scratch {
154                scratch.op_profiling_enabled = true;
155                scratch.op_us = [0u64; 16];
156            }
157        }
158
159        // 1. GPU forward → logits stay GPU-resident in training.logits_buf (KAIZEN-064)
160        self.profiler.begin(StepProfiler::FORWARD);
161        if !self.forward_logits_gpu_resident(full_ids) {
162            self.profiler.end(StepProfiler::FORWARD);
163            eprintln!("[CUDA] GPU forward failed, falling back to CPU for this step");
164            return self.cuda_train_step_cpu_loss(
165                full_ids,
166                loss_start,
167                loss_end,
168                num_loss_tokens,
169                seq_len,
170                vocab_size,
171            );
172        }
173        self.profiler.end(StepProfiler::FORWARD);
174
175        // 2. Fused GPU causal cross-entropy loss + softmax backward (KAIZEN-064)
176        let targets: Vec<u32> = (0..seq_len)
177            .map(|pos| if pos + 1 < full_ids.len() { full_ids[pos + 1] } else { 0 })
178            .collect();
179
180        let scale = 1.0 / num_loss_tokens as f32;
181
182        self.profiler.begin(StepProfiler::LOSS);
183        let avg_loss = (|| -> Option<f32> {
184            let trainer = self.cuda_trainer.as_ref()?;
185            let stream = trainer.stream();
186            let training = self.gpu_training.as_mut()?;
187            fused_causal_cross_entropy_cuda(
188                &mut training.logits_buf,
189                &targets,
190                seq_len as u32,
191                vocab_size as u32,
192                loss_start as u32,
193                loss_end as u32,
194                scale,
195                stream,
196            )
197            .ok()
198        })();
199        self.profiler.end(StepProfiler::LOSS);
200
201        let avg_loss = match avg_loss {
202            Some(l) if l.is_finite() => {
203                eprintln!("[CUDA] loss={l:.4} (finite, proceeding with backward)");
204                l
205            }
206            Some(l) => {
207                eprintln!("[CUDA] NaN/Inf loss detected (loss={l}) — skipping backward pass");
208                return InstructStepResult {
209                    loss: 100.0,
210                    num_response_tokens: num_loss_tokens,
211                    perplexity: 1e6,
212                };
213            }
214            None => {
215                eprintln!("[CUDA] fused causal cross-entropy failed — falling back to CPU");
216                return self.cuda_train_step_cpu_loss(
217                    full_ids,
218                    loss_start,
219                    loss_end,
220                    num_loss_tokens,
221                    seq_len,
222                    vocab_size,
223                );
224            }
225        };
226
227        // 3. GPU GEMM backward: grad_hidden = grad_logits @ embed (KAIZEN-064/065/068)
228        self.profiler.begin(StepProfiler::LM_BWD);
229        let hidden_size = self.model.config().hidden_size;
230
231        let gemm_ok = (|| -> Option<()> {
232            let trainer = self.cuda_trainer.as_ref()?;
233            let stream = trainer.stream();
234            let training = self.gpu_training.as_mut()?;
235            if training.embed_original.len() < vocab_size * hidden_size {
236                return None;
237            }
238            gemm_forward(
239                &training.logits_buf,
240                &training.embed_original,
241                &mut training.grad_hidden_buf,
242                seq_len as u32,
243                vocab_size as u32,
244                hidden_size as u32,
245                stream,
246            )
247            .map_err(|e| eprintln!("[CUDA] lm_head backward GEMM failed: {e}"))
248            .ok()?;
249            Some(())
250        })();
251
252        self.profiler.end(StepProfiler::LM_BWD);
253
254        if gemm_ok.is_none() {
255            // PMAT-471: CPU fallback when GPU embeddings don't fit
256            let cpu_ok = (|| -> Option<()> {
257                let trainer = self.cuda_trainer.as_ref()?;
258                let training = self.gpu_training.as_mut()?;
259                let embed = self.model.embed_tokens.weight.data();
260                let embed = embed.as_slice().expect("contiguous embed");
261                super::super::gpu_backward_fallback::cpu_lmhead_backward(
262                    trainer,
263                    &training.logits_buf,
264                    &mut training.grad_hidden_buf,
265                    embed,
266                    seq_len,
267                    vocab_size,
268                    hidden_size,
269                    trainer.stream(),
270                )
271            })();
272            if cpu_ok.is_none() {
273                return InstructStepResult {
274                    loss: avg_loss,
275                    num_response_tokens: num_loss_tokens,
276                    perplexity: avg_loss.exp().min(1e6),
277                };
278            }
279        }
280
281        // 4. GPU backward through NF4 blocks (KAIZEN-065: GPU-resident)
282        self.profiler.begin(StepProfiler::BLK_BWD);
283        if self.config.quantize_nf4 {
284            self.backward_nf4_gpu_blocks_gpu_resident(seq_len);
285        }
286        self.profiler.end(StepProfiler::BLK_BWD);
287
288        // PMAT-483: Feed per-layer timing from training state to profiler
289        if let Some(ref training) = self.gpu_training {
290            self.profiler.record_layer_times(
291                &training.profiler_layer_fwd_us,
292                &training.profiler_layer_bwd_us,
293            );
294        }
295
296        // PMAT-483/entrenar#328: Feed per-op timing from scratch to profiler
297        if let Some(ref scratch) = self.shared_scratch {
298            if scratch.op_profiling_enabled {
299                for (i, &us) in scratch.op_us.iter().enumerate() {
300                    if us > 0 {
301                        self.profiler.end_op_raw(i, us);
302                    }
303                }
304            }
305        }
306
307        InstructStepResult {
308            loss: avg_loss,
309            num_response_tokens: num_loss_tokens,
310            perplexity: avg_loss.exp().min(1e6),
311        }
312    }
313    /// CPU fallback for causal LM loss when GPU fused kernel is unavailable.
314    /// Used when forward_logits_gpu_resident or fused_causal_cross_entropy_cuda fails.
315    #[cfg(feature = "cuda")]
316    fn cuda_train_step_cpu_loss(
317        &mut self,
318        full_ids: &[u32],
319        loss_start: usize,
320        loss_end: usize,
321        num_loss_tokens: usize,
322        seq_len: usize,
323        vocab_size: usize,
324    ) -> InstructStepResult {
325        // PMAT-420: Check if GPU embeddings are available. If not (VRAM-constrained),
326        // skip forward_logits_gpu entirely to avoid CUDA context poisoning.
327        let has_gpu_embed = self.gpu_training.as_ref().is_some_and(|t| {
328            t.embed_original.len() >= vocab_size * self.model.config().hidden_size
329        });
330
331        let logits_data = if has_gpu_embed {
332            match self.forward_logits_gpu(full_ids) {
333                Some(data) => data,
334                None => {
335                    let logits = self.model.forward(full_ids);
336                    logits.data().as_slice().expect("contiguous logits").to_vec()
337                }
338            }
339        } else {
340            // PMAT-420: Inference-style forward + save inputs for backward
341            match self.forward_inference_saving_inputs(full_ids) {
342                Some(data) => data,
343                None => {
344                    let logits = self.model.forward(full_ids);
345                    logits.data().as_slice().expect("contiguous logits").to_vec()
346                }
347            }
348        };
349
350        let (avg_loss, grad_logits) =
351            Self::compute_causal_lm_loss(&logits_data, full_ids, loss_start, loss_end, vocab_size);
352
353        if !avg_loss.is_finite() {
354            return InstructStepResult {
355                loss: 100.0,
356                num_response_tokens: num_loss_tokens,
357                perplexity: 1e6,
358            };
359        }
360
361        let hidden_size = self.model.config().hidden_size;
362
363        let grad_hidden = (|| -> Option<Vec<f32>> {
364            let trainer = self.cuda_trainer.as_ref()?;
365            let stream = trainer.stream();
366            let training = self.gpu_training.as_mut()?;
367            if training.logits_buf.len() < grad_logits.len() {
368                return None;
369            }
370            training
371                .logits_buf
372                .copy_from_host_at(&grad_logits, 0)
373                .map_err(|e| eprintln!("[CUDA] lm_head backward: grad_logits upload failed: {e}"))
374                .ok()?;
375            if training.embed_original.len() < vocab_size * hidden_size {
376                return None;
377            }
378            gemm_forward(
379                &training.logits_buf,
380                &training.embed_original,
381                &mut training.grad_hidden_buf,
382                seq_len as u32,
383                vocab_size as u32,
384                hidden_size as u32,
385                stream,
386            )
387            .map_err(|e| eprintln!("[CUDA] lm_head backward GEMM failed: {e}"))
388            .ok()?;
389            stream.synchronize().ok()?;
390            let full_grad = trainer.download(&training.grad_hidden_buf).ok()?;
391            Some(full_grad[..seq_len * hidden_size].to_vec())
392        })();
393
394        let grad_hidden = match grad_hidden {
395            Some(g) => g,
396            None => {
397                let hidden_size = self.model.config().hidden_size;
398                let lm_weight =
399                    self.model.lm_head.as_ref().unwrap_or(&self.model.embed_tokens.weight);
400                let lm_data = lm_weight.data();
401                let lm_slice = lm_data.as_slice().expect("contiguous lm_head");
402                crate::autograd::ops::matmul::matmul_compute(
403                    &grad_logits[..seq_len * vocab_size],
404                    lm_slice,
405                    seq_len,
406                    vocab_size,
407                    hidden_size,
408                )
409            }
410        };
411
412        if self.config.quantize_nf4 {
413            let grad_nz = grad_hidden.iter().filter(|&&x| x != 0.0).count();
414            static BWD_LOG: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
415            if BWD_LOG.fetch_add(1, std::sync::atomic::Ordering::Relaxed) < 3 {
416                eprintln!(
417                    "[PMAT-420] backward: grad_hidden len={} nonzero={grad_nz} first5={:?}",
418                    grad_hidden.len(),
419                    &grad_hidden[..5.min(grad_hidden.len())]
420                );
421            }
422            self.backward_nf4_gpu_blocks(&grad_hidden, seq_len);
423        }
424
425        InstructStepResult {
426            loss: avg_loss,
427            num_response_tokens: num_loss_tokens,
428            perplexity: avg_loss.exp().min(1e6),
429        }
430    }
431    /// Evaluate loss and perplexity on a set of samples without updating weights.
432    pub fn evaluate(
433        &self,
434        prompt_ids_batch: &[Vec<u32>],
435        response_ids_batch: &[Vec<u32>],
436    ) -> InstructBatchResult {
437        let mut total_loss = 0.0f32;
438        let mut total_response_tokens = 0usize;
439
440        for (prompt_ids, response_ids) in prompt_ids_batch.iter().zip(response_ids_batch.iter()) {
441            let full_ids: Vec<u32> =
442                prompt_ids.iter().chain(response_ids.iter()).copied().collect();
443
444            let prompt_len = prompt_ids.len();
445            if response_ids.is_empty() || full_ids.len() < 2 {
446                continue;
447            }
448
449            let full_ids = if full_ids.len() > self.config.max_seq_len {
450                full_ids[..self.config.max_seq_len].to_vec()
451            } else {
452                full_ids
453            };
454            let seq_len = full_ids.len();
455            let vocab_size = self.model.config().vocab_size;
456            let prompt_len = prompt_len.min(seq_len);
457
458            let logits = self.model.forward(&full_ids);
459            let logits_data = logits.data().as_slice().expect("contiguous logits").to_vec();
460
461            let loss_start = prompt_len.saturating_sub(1);
462            let loss_end = seq_len - 1;
463            let num_loss_tokens = loss_end.saturating_sub(loss_start);
464
465            let (sample_loss, _) = Self::compute_causal_lm_loss(
466                &logits_data,
467                &full_ids,
468                loss_start,
469                loss_end,
470                vocab_size,
471            );
472
473            total_loss += sample_loss * num_loss_tokens as f32;
474            total_response_tokens += num_loss_tokens;
475        }
476
477        let avg_loss =
478            if total_response_tokens > 0 { total_loss / total_response_tokens as f32 } else { 0.0 };
479
480        InstructBatchResult {
481            avg_loss,
482            total_response_tokens,
483            perplexity: avg_loss.exp().min(1e6),
484            grad_norm: 0.0,
485        }
486    }
487    /// Compute causal LM loss and gradients for the given position range.
488    ///
489    /// Returns (average_loss, gradient_logits).
490    pub(super) fn compute_causal_lm_loss(
491        logits_data: &[f32],
492        full_ids: &[u32],
493        loss_start: usize,
494        loss_end: usize,
495        vocab_size: usize,
496    ) -> (f32, Vec<f32>) {
497        let seq_len = full_ids.len();
498        let num_loss_tokens = loss_end.saturating_sub(loss_start);
499        let mut total_loss = 0.0f32;
500        let mut grad_logits = vec![0.0f32; seq_len * vocab_size];
501
502        for pos in loss_start..loss_end {
503            let target = full_ids[pos + 1] as usize;
504            if target >= vocab_size {
505                continue;
506            }
507
508            let logit_start = pos * vocab_size;
509            let row = &logits_data[logit_start..logit_start + vocab_size];
510
511            let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
512            let grad_row = &mut grad_logits[logit_start..logit_start + vocab_size];
513            let mut sum_exp = 0.0f32;
514            for j in 0..vocab_size {
515                let exp_v = (row[j] - max_val).exp();
516                grad_row[j] = exp_v;
517                sum_exp += exp_v;
518            }
519
520            let log_sum_exp = sum_exp.ln() + max_val;
521            let loss_i = -(row[target] - log_sum_exp);
522            total_loss += if loss_i.is_finite() { loss_i } else { 100.0 };
523
524            let inv_n = 1.0 / num_loss_tokens as f32;
525            let scale = inv_n / sum_exp;
526            for j in 0..vocab_size {
527                grad_row[j] *= scale;
528            }
529            grad_row[target] -= inv_n;
530        }
531
532        let avg_loss = if num_loss_tokens > 0 { total_loss / num_loss_tokens as f32 } else { 0.0 };
533
534        (avg_loss, grad_logits)
535    }
536}