Skip to main content

entrenar/finetune/instruct_pipeline/
cuda_forward.rs

1#[cfg(feature = "cuda")]
2use super::{CudaBlockScratch, InstructGpuTrainingState, InstructPipeline, Transformer};
3
4#[cfg(feature = "cuda")]
5use crate::autograd::cuda_training::CudaTrainer;
6#[cfg(feature = "cuda")]
7use crate::transformer::CudaBlock;
8#[cfg(feature = "cuda")]
9use trueno_gpu::driver::{CaptureMode, GpuBuffer};
10
11#[cfg(feature = "cuda")]
12impl InstructPipeline {
13    /// GPU-accelerated forward pass saving layer inputs for backward.
14    #[allow(unsafe_code)]
15    pub(super) fn forward_cuda_training(
16        model: &Transformer,
17        token_ids: &[u32],
18        trainer: &CudaTrainer,
19        cuda_blocks: &mut [CudaBlock],
20        training_state: &mut InstructGpuTrainingState,
21        shared_scratch: &mut Option<CudaBlockScratch>,
22    ) -> Option<()> {
23        let seq_len = token_ids.len();
24        let hidden_size = model.config.hidden_size;
25        let max_seq_len = shared_scratch
26            .as_ref()
27            .map_or(model.config.max_position_embeddings.min(512), |s| s.max_seq_len(hidden_size));
28        let seq_len = if seq_len > max_seq_len { max_seq_len } else { seq_len };
29        if seq_len == 0 {
30            return None;
31        }
32
33        // Embed on CPU, upload to GPU
34        let hidden = model.embed_tokens.forward(token_ids);
35        let hidden_data = hidden.data();
36        let hidden_slice = hidden_data.as_slice().expect("contiguous hidden");
37
38        // PMAT-420 / entrenar#316: Use seq_len-sized fresh buffers (like inference forward).
39        training_state.fwd_scratch_a = trainer
40            .upload(hidden_slice)
41            .map_err(|e| eprintln!("[CUDA] embed upload failed: {e}"))
42            .ok()?;
43        training_state.fwd_scratch_b = trainer
44            .zeros(seq_len * hidden_size)
45            .map_err(|e| eprintln!("[CUDA] scratch_b alloc failed: {e}"))
46            .ok()?;
47
48        let scratch_a_ptr: *mut GpuBuffer<f32> =
49            std::ptr::from_mut(&mut training_state.fwd_scratch_a);
50        let scratch_b_ptr: *mut GpuBuffer<f32> =
51            std::ptr::from_mut(&mut training_state.fwd_scratch_b);
52        let mut input_is_a = true;
53
54        let stream = trainer.stream();
55        // entrenar#318: GPU-side scratch + training state zeroing (PMAT-453 NaN cascade fix).
56        if let Some(ref mut scratch) = shared_scratch.as_mut() {
57            scratch.zero_forward_buffers(stream);
58        }
59        for b in [
60            &mut training_state.grad_buf_a,
61            &mut training_state.grad_buf_b,
62            &mut training_state.grad_hidden_buf,
63            &mut training_state.output_scratch,
64            &mut training_state.logits_buf,
65        ] {
66            b.zero_async(stream).ok();
67        }
68
69        // PMAT-464: CUDA graph capture/replay (CUDA_GRAPH=1)
70        static USE_CUDA_GRAPH: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
71        let use_graph =
72            *USE_CUDA_GRAPH.get_or_init(|| std::env::var("CUDA_GRAPH").as_deref() == Ok("1"));
73
74        for (i, block_) in cuda_blocks.iter().enumerate() {
75            let _ = block_;
76            let expected_len = seq_len * hidden_size;
77            if training_state.layer_inputs[i].len() != expected_len {
78                training_state.layer_inputs[i] = trainer
79                    .zeros(expected_len)
80                    .map_err(|e| eprintln!("[CUDA] layer_input prealloc L{i}: {e}"))
81                    .ok()?;
82            }
83        }
84
85        if use_graph
86            && training_state.graph_cached_seq_len == seq_len
87            && training_state.forward_graph_exec.is_some()
88        {
89            // === GRAPH REPLAY ===
90            let exec = training_state.forward_graph_exec.as_ref().unwrap();
91            exec.launch(stream.raw())
92                .map_err(|e| eprintln!("[CUDA] Graph replay failed: {e}"))
93                .ok()?;
94            for _ in 0..cuda_blocks.len() {
95                input_is_a = !input_is_a;
96            }
97        } else {
98            // === Standard or first-capture forward ===
99            let capturing = use_graph && training_state.graph_cached_seq_len != seq_len;
100            if capturing {
101                // PMAT-063: Pre-allocate cuBLAS workspace before graph capture
102                if training_state.cublas_workspace.is_none() {
103                    training_state.cublas_workspace =
104                        super::super::gpu_backward_fallback::preallocate_cublas_workspace(trainer);
105                }
106                stream
107                    .begin_capture(CaptureMode::ThreadLocal)
108                    .map_err(|e| eprintln!("[CUDA] Graph capture begin failed: {e}"))
109                    .ok()?;
110            }
111
112            for (i, block) in cuda_blocks.iter_mut().enumerate() {
113                let (gpu_input, gpu_output) = unsafe {
114                    if input_is_a {
115                        (&*scratch_a_ptr, &mut *scratch_b_ptr)
116                    } else {
117                        (&*scratch_b_ptr, &mut *scratch_a_ptr)
118                    }
119                };
120
121                training_state.layer_inputs[i]
122                    .copy_from_buffer(gpu_input)
123                    .map_err(|e| eprintln!("[CUDA] layer_input copy L{i}: {e}"))
124                    .ok()?;
125
126                // PMAT-483: Per-layer forward profiling
127                training_state.profiler_layer_start = Some(std::time::Instant::now());
128
129                if let Err(e) =
130                    block.forward(gpu_input, gpu_output, seq_len, stream, shared_scratch.as_mut())
131                {
132                    eprintln!(
133                        "[CUDA] Layer {i} forward failed: {e} (seq_len={seq_len} in={} out={} hidden={hidden_size})",
134                        gpu_input.len(), gpu_output.len(),
135                    );
136                    if capturing {
137                        let _ = stream.end_capture();
138                    }
139                    return None;
140                }
141
142                // PMAT-483: Record per-layer forward time
143                if let Some(start) = training_state.profiler_layer_start.take() {
144                    training_state.profiler_layer_fwd_us[i] = start.elapsed().as_micros() as u64;
145                }
146
147                input_is_a = !input_is_a;
148            }
149
150            if capturing {
151                match stream.end_capture() {
152                    Ok(graph) => match graph.instantiate() {
153                        Ok(exec) => {
154                            eprintln!(
155                                "[CUDA] Graph captured: {} layers, seq_len={seq_len}",
156                                cuda_blocks.len()
157                            );
158                            training_state.forward_graph_exec = Some(exec);
159                            training_state.graph_cached_seq_len = seq_len;
160                        }
161                        Err(e) => {
162                            eprintln!(
163                                "[CUDA] Graph instantiate failed: {e} — using non-graph path"
164                            );
165                        }
166                    },
167                    Err(e) => {
168                        eprintln!("[CUDA] Graph end_capture failed: {e} — using non-graph path");
169                    }
170                }
171            }
172        }
173
174        let final_output = unsafe {
175            if input_is_a {
176                &*scratch_a_ptr
177            } else {
178                &*scratch_b_ptr
179            }
180        };
181
182        // Save blocks output for RMSNorm backward
183        if training_state.blocks_output.len() != final_output.len() {
184            training_state.blocks_output = trainer
185                .zeros(final_output.len())
186                .map_err(|e| eprintln!("[CUDA] blocks_output realloc failed: {e}"))
187                .ok()?;
188        }
189        training_state
190            .blocks_output
191            .copy_from_buffer(final_output)
192            .map_err(|e| eprintln!("[CUDA] blocks_output copy: {e}"))
193            .ok()?;
194
195        crate::autograd::cuda_backward::rms_norm_forward(
196            final_output,
197            &training_state.final_norm_weight,
198            &mut training_state.lm_head_hidden_buf,
199            seq_len as u32,
200            hidden_size as u32,
201            stream,
202        )
203        .map_err(|e| eprintln!("[CUDA] GPU RMSNorm forward failed: {e}"))
204        .ok()?;
205
206        Some(())
207    }
208    /// GPU-accelerated forward pass (inference-only, no layer input saving).
209    pub(super) fn forward_cuda_inference(
210        model: &Transformer,
211        token_ids: &[u32],
212        trainer: &CudaTrainer,
213        cuda_blocks: &mut [CudaBlock],
214        shared_scratch: &mut Option<CudaBlockScratch>,
215    ) -> Option<Vec<f32>> {
216        let seq_len = token_ids.len();
217        let hidden_size = model.config.hidden_size;
218
219        let hidden = model.embed_tokens.forward(token_ids);
220        let hidden_data = hidden.data();
221        let hidden_slice = hidden_data.as_slice().expect("contiguous hidden");
222
223        let mut gpu_input = trainer.upload(hidden_slice).ok()?;
224        let mut gpu_output = trainer.zeros(seq_len * hidden_size).ok()?;
225
226        let stream = trainer.stream();
227        for (i, block) in cuda_blocks.iter_mut().enumerate() {
228            if let Err(e) =
229                block.forward(&gpu_input, &mut gpu_output, seq_len, stream, shared_scratch.as_mut())
230            {
231                eprintln!("[CUDA] Layer {i} forward failed: {e}");
232                return None;
233            }
234            std::mem::swap(&mut gpu_input, &mut gpu_output);
235        }
236
237        if let Err(e) = stream.synchronize() {
238            eprintln!("[CUDA] Stream sync failed: {e}");
239            return None;
240        }
241
242        let result_data = trainer.download(&gpu_input).ok()?;
243        if result_data.iter().any(|v| !v.is_finite()) {
244            return None;
245        }
246
247        let result_tensor = crate::Tensor::from_vec(result_data, false);
248        let normed = model.norm.forward_batched(&result_tensor, seq_len, hidden_size);
249        let normed_data = normed.data();
250        let normed_slice = normed_data.as_slice().expect("contiguous normed");
251        Some(normed_slice.to_vec())
252    }
253    /// Forward pass dispatching to GPU. Returns logits as flat Vec<f32> [seq_len, vocab_size].
254    /// lm_head GEMM runs on GPU: hidden[seq, hidden] @ embed_T[hidden, vocab] -> logits[seq, vocab]
255    pub(super) fn forward_logits_gpu(&mut self, token_ids: &[u32]) -> Option<Vec<f32>> {
256        let seq_len = token_ids.len();
257        let vocab_size = self.model.config().vocab_size;
258        let hidden_size = self.model.config().hidden_size;
259
260        if self.gpu_training.is_some() {
261            let (trainer, blocks) = match (&self.cuda_trainer, &mut self.cuda_blocks) {
262                (Some(ref t), Some(ref mut b)) => (t, b),
263                _ => return None,
264            };
265            let mut training = self.gpu_training.take();
266            let result = Self::forward_cuda_training(
267                &self.model,
268                token_ids,
269                trainer,
270                blocks,
271                training.as_mut().expect("gpu_training was Some"),
272                &mut self.shared_scratch,
273            );
274            self.gpu_training = training;
275            result?;
276        } else {
277            let (trainer, blocks) = match (&self.cuda_trainer, &mut self.cuda_blocks) {
278                (Some(ref t), Some(ref mut b)) => (t, b),
279                _ => return None,
280            };
281            let normed_hidden = Self::forward_cuda_inference(
282                &self.model,
283                token_ids,
284                trainer,
285                blocks,
286                &mut self.shared_scratch,
287            )?;
288            let training = self.gpu_training.as_mut()?;
289            training
290                .lm_head_hidden_buf
291                .copy_from_host_at(&normed_hidden, 0)
292                .map_err(|e| eprintln!("[CUDA] lm_head forward: hidden upload failed: {e}"))
293                .ok()?;
294        }
295
296        let trainer = self.cuda_trainer.as_ref()?;
297        let training = self.gpu_training.as_mut()?;
298        let stream = trainer.stream();
299
300        eprintln!("[CUDA] lm_head BT: hidden_len={} embed_len={} logits_len={} seq={seq_len} h={hidden_size} v={vocab_size}",
301            training.lm_head_hidden_buf.len(), training.embed_original.len(), training.logits_buf.len());
302        if let Err(e) = crate::autograd::cuda_forward::gemm_forward_bt(
303            &training.lm_head_hidden_buf,
304            &training.embed_original,
305            &mut training.logits_buf,
306            seq_len as u32,
307            hidden_size as u32,
308            vocab_size as u32,
309            stream,
310        ) {
311            eprintln!("[CUDA] lm_head forward GEMM (BT) failed: {e}");
312            return None;
313        }
314
315        if let Err(e) = stream.synchronize() {
316            eprintln!("[CUDA] lm_head forward sync failed: {e}");
317            return None;
318        }
319
320        let full_logits = trainer
321            .download(&training.logits_buf)
322            .map_err(|e| eprintln!("[CUDA] lm_head forward: logits download failed: {e}"))
323            .ok()?;
324        Some(full_logits[..seq_len * vocab_size].to_vec())
325    }
326    /// PMAT-420: Inference forward + save layer inputs for backward.
327    /// Uses inference-style fresh buffers (no NaN) but saves layer inputs for GPU backward.
328    pub(super) fn forward_inference_saving_inputs(
329        &mut self,
330        token_ids: &[u32],
331    ) -> Option<Vec<f32>> {
332        let seq_len = token_ids.len();
333        let hidden_size = self.model.config().hidden_size;
334        let vocab_size = self.model.config().vocab_size;
335
336        let trainer = self.cuda_trainer.as_ref()?;
337        let blocks = self.cuda_blocks.as_mut()?;
338        let stream = trainer.stream();
339
340        let hidden = self.model.embed_tokens.forward(token_ids);
341        let hidden_data = hidden.data();
342        let hidden_slice = hidden_data.as_slice().expect("contiguous hidden");
343
344        let mut gpu_input = trainer.upload(hidden_slice).ok()?;
345        let mut gpu_output = trainer.zeros(seq_len * hidden_size).ok()?;
346
347        for (i, block) in blocks.iter_mut().enumerate() {
348            if let Some(ref mut training) = self.gpu_training {
349                if i < training.layer_inputs.len() {
350                    if training.layer_inputs[i].len() != gpu_input.len() {
351                        if let Ok(buf) = trainer.zeros(gpu_input.len()) {
352                            training.layer_inputs[i] = buf;
353                        }
354                    }
355                    training.layer_inputs[i]
356                        .copy_from_buffer(&gpu_input)
357                        .map_err(|e| eprintln!("[CUDA] layer_input copy L{i}: {e}"))
358                        .ok();
359                }
360            }
361
362            if let Err(e) = block.forward(
363                &gpu_input,
364                &mut gpu_output,
365                seq_len,
366                stream,
367                self.shared_scratch.as_mut(),
368            ) {
369                eprintln!("[CUDA] Layer {i} forward failed: {e}");
370                return None;
371            }
372            std::mem::swap(&mut gpu_input, &mut gpu_output);
373        }
374
375        stream.synchronize().ok()?;
376
377        // Save blocks_output for RMSNorm backward
378        if let Some(ref mut training) = self.gpu_training {
379            if training.blocks_output.len() != gpu_input.len() {
380                if let Ok(buf) = trainer.zeros(gpu_input.len()) {
381                    training.blocks_output = buf;
382                }
383            }
384            training
385                .blocks_output
386                .copy_from_buffer(&gpu_input)
387                .map_err(|e| eprintln!("[CUDA] blocks_output copy: {e}"))
388                .ok();
389        }
390
391        let result = trainer.download(&gpu_input).ok()?;
392        if result.iter().any(|v| !v.is_finite()) {
393            eprintln!("[CUDA] NaN in forward output — inference-style forward failed");
394            return None;
395        }
396
397        // CPU RMSNorm
398        let result_tensor = crate::autograd::Tensor::from_vec(result, false);
399        let normed = self.model.norm.forward_batched(&result_tensor, seq_len, hidden_size);
400        let normed_data = normed.data();
401        let normed_slice = normed_data.as_slice().expect("contiguous normed");
402
403        // Save normed hidden for lm_head backward
404        if let Some(ref mut training) = self.gpu_training {
405            if let Ok(buf) = trainer.upload(normed_slice) {
406                training.lm_head_hidden_buf = buf;
407            }
408        }
409
410        // CPU lm_head
411        let lm_weight = self.model.lm_head.as_ref().unwrap_or(&self.model.embed_tokens.weight);
412        let lm_data = lm_weight.data();
413        let lm_slice = lm_data.as_slice().expect("contiguous lm_head");
414        let logits = crate::autograd::ops::matmul::matmul_nt_compute(
415            normed_slice,
416            lm_slice,
417            seq_len,
418            hidden_size,
419            vocab_size,
420        );
421        Some(logits)
422    }
423    /// GPU forward with logits staying GPU-resident (KAIZEN-064).
424    /// After this call, `training.logits_buf` contains logits on GPU. Returns true on success.
425    pub(super) fn forward_logits_gpu_resident(&mut self, token_ids: &[u32]) -> bool {
426        let seq_len = token_ids.len();
427        let vocab_size = self.model.config().vocab_size;
428        let hidden_size = self.model.config().hidden_size;
429
430        if self.gpu_training.is_some() {
431            let (trainer, blocks) = match (&self.cuda_trainer, &mut self.cuda_blocks) {
432                (Some(ref t), Some(ref mut b)) => (t, b),
433                _ => {
434                    eprintln!("[RES-FALSE] no trainer/blocks");
435                    return false;
436                }
437            };
438            let mut training = self.gpu_training.take();
439            let result = Self::forward_cuda_training(
440                &self.model,
441                token_ids,
442                trainer,
443                blocks,
444                training.as_mut().expect("gpu_training was Some"),
445                &mut self.shared_scratch,
446            );
447            self.gpu_training = training;
448            if result.is_none() {
449                eprintln!("[RES-FALSE] forward_cuda_training returned None");
450                return false;
451            }
452        } else {
453            let (trainer, blocks) = match (&self.cuda_trainer, &mut self.cuda_blocks) {
454                (Some(ref t), Some(ref mut b)) => (t, b),
455                _ => return false,
456            };
457            let normed_hidden = match Self::forward_cuda_inference(
458                &self.model,
459                token_ids,
460                trainer,
461                blocks,
462                &mut self.shared_scratch,
463            ) {
464                Some(h) => h,
465                None => return false,
466            };
467            let training = match self.gpu_training.as_mut() {
468                Some(t) => t,
469                None => return false,
470            };
471            if training.lm_head_hidden_buf.copy_from_host_at(&normed_hidden, 0).is_err() {
472                eprintln!("[CUDA] lm_head forward: hidden upload failed");
473                return false;
474            }
475        }
476
477        let (trainer, training) = match (&self.cuda_trainer, &mut self.gpu_training) {
478            (Some(ref t), Some(ref mut tr)) => (t, tr),
479            _ => {
480                eprintln!("[RES-FALSE] no trainer/training");
481                return false;
482            }
483        };
484
485        let stream = trainer.stream();
486
487        if crate::autograd::cuda_forward::gemm_forward_bt(
488            &training.lm_head_hidden_buf,
489            &training.embed_original,
490            &mut training.logits_buf,
491            seq_len as u32,
492            hidden_size as u32,
493            vocab_size as u32,
494            stream,
495        )
496        .is_err()
497        {
498            eprintln!("[CUDA] lm_head forward GEMM (BT) failed");
499            eprintln!("[RES-FALSE] BT GEMM failed");
500            return false;
501        }
502
503        true
504    }
505}