entrenar/finetune/instruct_pipeline/
cuda_forward.rs1#[cfg(feature = "cuda")]
2use super::{CudaBlockScratch, InstructGpuTrainingState, InstructPipeline, Transformer};
3
4#[cfg(feature = "cuda")]
5use crate::autograd::cuda_training::CudaTrainer;
6#[cfg(feature = "cuda")]
7use crate::transformer::CudaBlock;
8#[cfg(feature = "cuda")]
9use trueno_gpu::driver::{CaptureMode, GpuBuffer};
10
// GPU forward passes for the instruct fine-tuning pipeline: the activation-saving
// training forward, a plain inference forward, and the lm_head logits helpers.
#[cfg(feature = "cuda")]
impl InstructPipeline {
    /// Runs embedding, every decoder block and the final RMSNorm on the GPU for a
    /// training step, leaving the intermediates in `training_state`.
    ///
    /// On success:
    /// - `layer_inputs[i]` holds the input activations of block `i`,
    /// - `blocks_output` holds the output of the last block (before the norm),
    /// - `lm_head_hidden_buf` holds the `rms_norm_forward` output.
    ///
    /// The sequence is truncated to the forward scratch capacity
    /// (`shared_scratch.max_seq_len(hidden_size)`, or
    /// `min(max_position_embeddings, 512)` when there is no shared scratch). Only
    /// the first `seq_len` tokens are embedded and processed.
    ///
    /// With the environment variable `CUDA_GRAPH=1` (read once per process), the
    /// per-layer loop is recorded into a CUDA graph the first time a sequence
    /// length is seen. Later calls with the same length replay that graph.
    ///
    /// Returns `None` for an empty input or on any CUDA failure (logged to stderr).
    pub(super) fn forward_cuda_training(
        model: &Transformer,
        token_ids: &[u32],
        trainer: &CudaTrainer,
        cuda_blocks: &mut [CudaBlock],
        training_state: &mut InstructGpuTrainingState,
        shared_scratch: &mut Option<CudaBlockScratch>,
    ) -> Option<()> {
        static USE_CUDA_GRAPH: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
        let use_graph =
            *USE_CUDA_GRAPH.get_or_init(|| std::env::var("CUDA_GRAPH").as_deref() == Ok("1"));

        let hidden_size = model.config.hidden_size;
        let max_seq_len = shared_scratch
            .as_ref()
            .map_or(model.config.max_position_embeddings.min(512), |s| s.max_seq_len(hidden_size));
        let seq_len = token_ids.len().min(max_seq_len);
        if seq_len == 0 {
            return None;
        }
        let activations_len = seq_len * hidden_size;

        // Embed only the tokens that are actually processed. Every device buffer
        // below is sized for `seq_len` rows, so embedding the untruncated sequence
        // would hand the blocks an input larger than their output/snapshot buffers.
        let hidden = model.embed_tokens.forward(&token_ids[..seq_len]);
        let hidden_data = hidden.data();
        let hidden_slice = hidden_data.as_slice().expect("contiguous hidden");

        Self::stage_forward_scratch(trainer, training_state, hidden_slice, activations_len, use_graph)?;

        let stream = trainer.stream();
        if let Some(scratch) = shared_scratch.as_mut() {
            scratch.zero_forward_buffers(stream);
        }
        // Best-effort clear of the gradient / lm_head buffers; failures are ignored.
        for buf in [
            &mut training_state.grad_buf_a,
            &mut training_state.grad_buf_b,
            &mut training_state.grad_hidden_buf,
            &mut training_state.output_scratch,
            &mut training_state.logits_buf,
        ] {
            buf.zero_async(stream).ok();
        }

        // Make sure each block's input snapshot buffer matches this sequence length.
        for (i, layer_input) in training_state.layer_inputs[..cuda_blocks.len()]
            .iter_mut()
            .enumerate()
        {
            if layer_input.len() != activations_len {
                *layer_input = trainer
                    .zeros(activations_len)
                    .map_err(|e| eprintln!("[CUDA] layer_input prealloc L{i}: {e}"))
                    .ok()?;
            }
        }

        let replay_exec = if use_graph && training_state.graph_cached_seq_len == seq_len {
            training_state.forward_graph_exec.as_ref()
        } else {
            None
        };

        let output_in_a = if let Some(exec) = replay_exec {
            exec.launch(stream.raw())
                .map_err(|e| eprintln!("[CUDA] Graph replay failed: {e}"))
                .ok()?;
            // Each block swaps the ping-pong buffers once, so after an even number
            // of blocks the result is back in `fwd_scratch_a`.
            cuda_blocks.len() % 2 == 0
        } else if use_graph && training_state.graph_cached_seq_len != seq_len {
            Self::capture_forward_graph(
                trainer,
                cuda_blocks,
                training_state,
                shared_scratch,
                seq_len,
                hidden_size,
            )?
        } else {
            Self::run_blocks_ping_pong(
                trainer,
                cuda_blocks,
                training_state,
                shared_scratch,
                seq_len,
                hidden_size,
            )?
        };

        let final_output: &GpuBuffer<f32> = if output_in_a {
            &training_state.fwd_scratch_a
        } else {
            &training_state.fwd_scratch_b
        };

        if training_state.blocks_output.len() != final_output.len() {
            training_state.blocks_output = trainer
                .zeros(final_output.len())
                .map_err(|e| eprintln!("[CUDA] blocks_output realloc failed: {e}"))
                .ok()?;
        }
        training_state
            .blocks_output
            .copy_from_buffer(final_output)
            .map_err(|e| eprintln!("[CUDA] blocks_output copy: {e}"))
            .ok()?;

        crate::autograd::cuda_backward::rms_norm_forward(
            final_output,
            &training_state.final_norm_weight,
            &mut training_state.lm_head_hidden_buf,
            seq_len as u32,
            hidden_size as u32,
            stream,
        )
        .map_err(|e| eprintln!("[CUDA] GPU RMSNorm forward failed: {e}"))
        .ok()?;

        Some(())
    }

    /// Places the embedded sequence in `fwd_scratch_a` and a zeroed buffer of
    /// `activations_len` elements in `fwd_scratch_b`.
    ///
    /// With `keep_allocations` set (CUDA-graph mode), existing allocations whose
    /// sizes already match are refilled in place. An instantiated graph replays
    /// the device addresses it captured, so reallocating would leave it pointing
    /// at freed memory. Otherwise fresh buffers are allocated on every call.
    fn stage_forward_scratch(
        trainer: &CudaTrainer,
        training_state: &mut InstructGpuTrainingState,
        hidden: &[f32],
        activations_len: usize,
        keep_allocations: bool,
    ) -> Option<()> {
        let reuse = keep_allocations
            && training_state.fwd_scratch_a.len() == hidden.len()
            && training_state.fwd_scratch_b.len() == activations_len;

        if reuse {
            training_state
                .fwd_scratch_a
                .copy_from_host_at(hidden, 0)
                .map_err(|e| eprintln!("[CUDA] embed upload failed: {e}"))
                .ok()?;
            if training_state.fwd_scratch_b.zero_async(trainer.stream()).is_err() {
                eprintln!("[CUDA] scratch_b zero failed");
                return None;
            }
        } else {
            training_state.fwd_scratch_a = trainer
                .upload(hidden)
                .map_err(|e| eprintln!("[CUDA] embed upload failed: {e}"))
                .ok()?;
            training_state.fwd_scratch_b = trainer
                .zeros(activations_len)
                .map_err(|e| eprintln!("[CUDA] scratch_b alloc failed: {e}"))
                .ok()?;
        }
        Some(())
    }

    /// Runs each block once on the trainer's stream, alternating `fwd_scratch_a`
    /// and `fwd_scratch_b` as input and output (starting with `a` as input). Each
    /// block's input is snapshotted into `layer_inputs[i]`.
    ///
    /// Returns `Some(true)` when the last output is in `fwd_scratch_a`,
    /// `Some(false)` when it is in `fwd_scratch_b`, and `None` (after logging) on
    /// the first failure. This function does not touch stream-capture state; a
    /// capturing caller must end the capture itself.
    fn run_blocks_ping_pong(
        trainer: &CudaTrainer,
        cuda_blocks: &mut [CudaBlock],
        training_state: &mut InstructGpuTrainingState,
        shared_scratch: &mut Option<CudaBlockScratch>,
        seq_len: usize,
        hidden_size: usize,
    ) -> Option<bool> {
        let stream = trainer.stream();
        let mut input_is_a = true;

        for (i, block) in cuda_blocks.iter_mut().enumerate() {
            // Disjoint field borrows: one scratch buffer is read while the other is
            // written, so no raw-pointer aliasing is needed.
            let (gpu_input, gpu_output): (&GpuBuffer<f32>, &mut GpuBuffer<f32>) = if input_is_a {
                (&training_state.fwd_scratch_a, &mut training_state.fwd_scratch_b)
            } else {
                (&training_state.fwd_scratch_b, &mut training_state.fwd_scratch_a)
            };

            training_state.layer_inputs[i]
                .copy_from_buffer(gpu_input)
                .map_err(|e| eprintln!("[CUDA] layer_input copy L{i}: {e}"))
                .ok()?;

            training_state.profiler_layer_start = Some(std::time::Instant::now());

            if let Err(e) =
                block.forward(gpu_input, gpu_output, seq_len, stream, shared_scratch.as_mut())
            {
                eprintln!(
                    "[CUDA] Layer {i} forward failed: {e} (seq_len={seq_len} in={} out={} hidden={hidden_size})",
                    gpu_input.len(),
                    gpu_output.len(),
                );
                return None;
            }

            if let Some(start) = training_state.profiler_layer_start.take() {
                training_state.profiler_layer_fwd_us[i] = start.elapsed().as_micros() as u64;
            }

            input_is_a = !input_is_a;
        }

        Some(input_is_a)
    }

    /// Records the per-layer forward into a CUDA graph, caches the instantiated
    /// graph for `seq_len`, and launches it once.
    ///
    /// Work issued while a stream is being captured is only recorded, not
    /// executed, so the new graph must be launched for this step's activations
    /// to exist. If ending the capture or instantiating the graph fails, the
    /// blocks run eagerly instead. Return value matches `run_blocks_ping_pong`.
    fn capture_forward_graph(
        trainer: &CudaTrainer,
        cuda_blocks: &mut [CudaBlock],
        training_state: &mut InstructGpuTrainingState,
        shared_scratch: &mut Option<CudaBlockScratch>,
        seq_len: usize,
        hidden_size: usize,
    ) -> Option<bool> {
        let stream = trainer.stream();

        // Allocated before capture begins (presumably so cuBLAS does not need to
        // allocate while the stream is being captured).
        if training_state.cublas_workspace.is_none() {
            training_state.cublas_workspace =
                super::super::gpu_backward_fallback::preallocate_cublas_workspace(trainer);
        }

        stream
            .begin_capture(CaptureMode::ThreadLocal)
            .map_err(|e| eprintln!("[CUDA] Graph capture begin failed: {e}"))
            .ok()?;

        let recorded = Self::run_blocks_ping_pong(
            trainer,
            cuda_blocks,
            training_state,
            shared_scratch,
            seq_len,
            hidden_size,
        );
        // Always leave capture mode, even if recording failed part-way. Returning
        // with the stream still capturing would poison every later launch on it.
        let captured = stream.end_capture();
        let output_in_a = recorded?;

        let exec = match captured {
            Ok(graph) => match graph.instantiate() {
                Ok(exec) => Some(exec),
                Err(e) => {
                    eprintln!("[CUDA] Graph instantiate failed: {e} — using non-graph path");
                    None
                }
            },
            Err(e) => {
                eprintln!("[CUDA] Graph end_capture failed: {e} — using non-graph path");
                None
            }
        };

        let Some(exec) = exec else {
            // Nothing recorded above has executed; compute this step eagerly.
            return Self::run_blocks_ping_pong(
                trainer,
                cuda_blocks,
                training_state,
                shared_scratch,
                seq_len,
                hidden_size,
            );
        };

        eprintln!(
            "[CUDA] Graph captured: {} layers, seq_len={seq_len}",
            cuda_blocks.len()
        );
        training_state.graph_cached_seq_len = seq_len;
        let exec = training_state.forward_graph_exec.insert(exec);
        exec.launch(stream.raw())
            .map_err(|e| eprintln!("[CUDA] Graph replay failed: {e}"))
            .ok()?;

        Some(output_in_a)
    }

    /// Runs embedding and all decoder blocks on the GPU without saving any
    /// intermediates. It then downloads the last block's output and applies the
    /// final norm on the CPU.
    ///
    /// Returns the normalised hidden states flattened to
    /// `token_ids.len() * hidden_size` values. Returns `None` for an empty input,
    /// on any CUDA failure, or when the blocks' output contains non-finite values.
    pub(super) fn forward_cuda_inference(
        model: &Transformer,
        token_ids: &[u32],
        trainer: &CudaTrainer,
        cuda_blocks: &mut [CudaBlock],
        shared_scratch: &mut Option<CudaBlockScratch>,
    ) -> Option<Vec<f32>> {
        let seq_len = token_ids.len();
        if seq_len == 0 {
            return None;
        }
        let hidden_size = model.config.hidden_size;

        let hidden = model.embed_tokens.forward(token_ids);
        let hidden_data = hidden.data();
        let hidden_slice = hidden_data.as_slice().expect("contiguous hidden");

        let mut gpu_input = trainer
            .upload(hidden_slice)
            .map_err(|e| eprintln!("[CUDA] embed upload failed: {e}"))
            .ok()?;
        let mut gpu_output = trainer
            .zeros(seq_len * hidden_size)
            .map_err(|e| eprintln!("[CUDA] inference scratch alloc failed: {e}"))
            .ok()?;

        let stream = trainer.stream();
        for (i, block) in cuda_blocks.iter_mut().enumerate() {
            if let Err(e) =
                block.forward(&gpu_input, &mut gpu_output, seq_len, stream, shared_scratch.as_mut())
            {
                eprintln!("[CUDA] Layer {i} forward failed: {e}");
                return None;
            }
            // This block's output becomes the next block's input.
            std::mem::swap(&mut gpu_input, &mut gpu_output);
        }

        if let Err(e) = stream.synchronize() {
            eprintln!("[CUDA] Stream sync failed: {e}");
            return None;
        }

        let result_data = trainer
            .download(&gpu_input)
            .map_err(|e| eprintln!("[CUDA] blocks output download failed: {e}"))
            .ok()?;
        if result_data.iter().any(|v| !v.is_finite()) {
            eprintln!("[CUDA] NaN in forward output — inference forward failed");
            return None;
        }

        let result_tensor = crate::Tensor::from_vec(result_data, false);
        let normed = model.norm.forward_batched(&result_tensor, seq_len, hidden_size);
        let normed_data = normed.data();
        let normed_slice = normed_data.as_slice().expect("contiguous normed");
        Some(normed_slice.to_vec())
    }

    /// Computes logits for `token_ids` on the GPU and downloads them.
    ///
    /// Runs `forward_cuda_training`, so the activations are also left in the GPU
    /// training state. It then runs the `gemm_forward_bt` lm_head GEMM of
    /// `lm_head_hidden_buf` against `embed_original` into `logits_buf`,
    /// synchronizes the stream, and downloads the first
    /// `token_ids.len() * vocab_size` values.
    ///
    /// Returns `None` when the CUDA trainer, CUDA blocks or GPU training state is
    /// missing, or when any CUDA step fails.
    pub(super) fn forward_logits_gpu(&mut self, token_ids: &[u32]) -> Option<Vec<f32>> {
        let seq_len = token_ids.len();
        let vocab_size = self.model.config().vocab_size;
        let hidden_size = self.model.config().hidden_size;

        // The lm_head GEMM reads and writes buffers owned by the GPU training
        // state, so without it no logits can be produced. Bail out before running
        // any forward work.
        let (Some(trainer), Some(blocks), Some(training)) = (
            self.cuda_trainer.as_ref(),
            self.cuda_blocks.as_mut(),
            self.gpu_training.as_mut(),
        ) else {
            return None;
        };

        Self::forward_cuda_training(
            &self.model,
            token_ids,
            trainer,
            blocks,
            training,
            &mut self.shared_scratch,
        )?;

        // NOTE(review): `forward_cuda_training` truncates to its scratch capacity,
        // while this GEMM uses the full `token_ids.len()`. Confirm callers never
        // exceed that capacity.
        let stream = trainer.stream();

        eprintln!("[CUDA] lm_head BT: hidden_len={} embed_len={} logits_len={} seq={seq_len} h={hidden_size} v={vocab_size}",
            training.lm_head_hidden_buf.len(), training.embed_original.len(), training.logits_buf.len());
        if let Err(e) = crate::autograd::cuda_forward::gemm_forward_bt(
            &training.lm_head_hidden_buf,
            &training.embed_original,
            &mut training.logits_buf,
            seq_len as u32,
            hidden_size as u32,
            vocab_size as u32,
            stream,
        ) {
            eprintln!("[CUDA] lm_head forward GEMM (BT) failed: {e}");
            return None;
        }

        if let Err(e) = stream.synchronize() {
            eprintln!("[CUDA] lm_head forward sync failed: {e}");
            return None;
        }

        let mut logits = trainer
            .download(&training.logits_buf)
            .map_err(|e| eprintln!("[CUDA] lm_head forward: logits download failed: {e}"))
            .ok()?;
        let wanted = seq_len * vocab_size;
        if logits.len() < wanted {
            eprintln!(
                "[CUDA] lm_head forward: logits buffer holds {} values, need {wanted}",
                logits.len()
            );
            return None;
        }
        // Reuse the downloaded allocation instead of copying the prefix.
        logits.truncate(wanted);
        Some(logits)
    }

    /// Inference-style GPU forward that also records the inputs the GPU backward
    /// pass reads, then computes the logits on the CPU.
    ///
    /// When GPU training state exists, the following best-effort copies are made
    /// (failures are logged or ignored and do not abort the forward):
    /// - each block's input is copied into `layer_inputs[i]`,
    /// - the last block's output is copied into `blocks_output`,
    /// - the CPU-normalised hidden states are uploaded into `lm_head_hidden_buf`.
    ///
    /// The lm_head matmul uses `lm_head` when present and the embedding weights
    /// otherwise.
    ///
    /// Returns `token_ids.len() * vocab_size` logits, or `None` for an empty
    /// input, missing CUDA trainer/blocks, a block failure, a failed sync or
    /// download, or non-finite blocks output.
    pub(super) fn forward_inference_saving_inputs(
        &mut self,
        token_ids: &[u32],
    ) -> Option<Vec<f32>> {
        let seq_len = token_ids.len();
        if seq_len == 0 {
            return None;
        }
        let hidden_size = self.model.config().hidden_size;
        let vocab_size = self.model.config().vocab_size;

        let trainer = self.cuda_trainer.as_ref()?;
        let blocks = self.cuda_blocks.as_mut()?;
        let stream = trainer.stream();

        let hidden = self.model.embed_tokens.forward(token_ids);
        let hidden_data = hidden.data();
        let hidden_slice = hidden_data.as_slice().expect("contiguous hidden");

        let mut gpu_input = trainer.upload(hidden_slice).ok()?;
        let mut gpu_output = trainer.zeros(seq_len * hidden_size).ok()?;

        for (i, block) in blocks.iter_mut().enumerate() {
            // Snapshot this block's input for the backward pass, if there is room.
            if let Some(training) = self.gpu_training.as_mut() {
                if let Some(layer_input) = training.layer_inputs.get_mut(i) {
                    if layer_input.len() != gpu_input.len() {
                        if let Ok(buf) = trainer.zeros(gpu_input.len()) {
                            *layer_input = buf;
                        }
                    }
                    if let Err(e) = layer_input.copy_from_buffer(&gpu_input) {
                        eprintln!("[CUDA] layer_input copy L{i}: {e}");
                    }
                }
            }

            if let Err(e) = block.forward(
                &gpu_input,
                &mut gpu_output,
                seq_len,
                stream,
                self.shared_scratch.as_mut(),
            ) {
                eprintln!("[CUDA] Layer {i} forward failed: {e}");
                return None;
            }
            std::mem::swap(&mut gpu_input, &mut gpu_output);
        }

        if let Err(e) = stream.synchronize() {
            eprintln!("[CUDA] Stream sync failed: {e}");
            return None;
        }

        if let Some(training) = self.gpu_training.as_mut() {
            if training.blocks_output.len() != gpu_input.len() {
                if let Ok(buf) = trainer.zeros(gpu_input.len()) {
                    training.blocks_output = buf;
                }
            }
            if let Err(e) = training.blocks_output.copy_from_buffer(&gpu_input) {
                eprintln!("[CUDA] blocks_output copy: {e}");
            }
        }

        let result = trainer.download(&gpu_input).ok()?;
        if result.iter().any(|v| !v.is_finite()) {
            eprintln!("[CUDA] NaN in forward output — inference-style forward failed");
            return None;
        }

        let result_tensor = crate::autograd::Tensor::from_vec(result, false);
        let normed = self.model.norm.forward_batched(&result_tensor, seq_len, hidden_size);
        let normed_data = normed.data();
        let normed_slice = normed_data.as_slice().expect("contiguous normed");

        if let Some(training) = self.gpu_training.as_mut() {
            if let Ok(buf) = trainer.upload(normed_slice) {
                training.lm_head_hidden_buf = buf;
            }
        }

        let lm_weight = self.model.lm_head.as_ref().unwrap_or(&self.model.embed_tokens.weight);
        let lm_data = lm_weight.data();
        let lm_slice = lm_data.as_slice().expect("contiguous lm_head");
        let logits = crate::autograd::ops::matmul::matmul_nt_compute(
            normed_slice,
            lm_slice,
            seq_len,
            hidden_size,
            vocab_size,
        );
        Some(logits)
    }

    /// Like `forward_logits_gpu`, but leaves the logits on the device in
    /// `logits_buf`. It neither synchronizes the stream nor downloads anything.
    ///
    /// Returns `true` once the lm_head GEMM has been enqueued. Returns `false`
    /// when the GPU training state, CUDA trainer or CUDA blocks are missing, or
    /// when the forward or GEMM fails; the reason is logged in most cases.
    pub(super) fn forward_logits_gpu_resident(&mut self, token_ids: &[u32]) -> bool {
        let seq_len = token_ids.len();
        let vocab_size = self.model.config().vocab_size;
        let hidden_size = self.model.config().hidden_size;

        // Logits live in the GPU training state; without it there is nothing to do.
        let Some(training) = self.gpu_training.as_mut() else {
            return false;
        };
        let (Some(trainer), Some(blocks)) = (self.cuda_trainer.as_ref(), self.cuda_blocks.as_mut())
        else {
            eprintln!("[RES-FALSE] no trainer/blocks");
            return false;
        };

        if Self::forward_cuda_training(
            &self.model,
            token_ids,
            trainer,
            blocks,
            training,
            &mut self.shared_scratch,
        )
        .is_none()
        {
            eprintln!("[RES-FALSE] forward_cuda_training returned None");
            return false;
        }

        let stream = trainer.stream();

        if let Err(e) = crate::autograd::cuda_forward::gemm_forward_bt(
            &training.lm_head_hidden_buf,
            &training.embed_original,
            &mut training.logits_buf,
            seq_len as u32,
            hidden_size as u32,
            vocab_size as u32,
            stream,
        ) {
            eprintln!("[CUDA] lm_head forward GEMM (BT) failed: {e}");
            eprintln!("[RES-FALSE] BT GEMM failed");
            return false;
        }

        true
    }
}