Skip to main content

entrenar/finetune/instruct_pipeline/
backward.rs

1//! NF4 QLoRA backward pass: `backward_nf4_gpu_blocks`,
2//! `backward_nf4_gpu_blocks_gpu_resident`, `backward_nf4_gpu_blocks_loop`.
3
4#[allow(clippy::wildcard_imports)]
5use super::*;
6
7#[cfg(feature = "cuda")]
8use trueno_gpu::driver::GpuBuffer;
9
10impl InstructPipeline {
11    /// NF4 QLoRA backward pass through all GPU transformer blocks.
12    ///
13    /// Computes gradient flow through frozen NF4 weights and updates LoRA
14    /// adapters. After each block backward, immediately runs the LoRA optimizer
15    /// step (grad workspace is shared across layers).
16    #[cfg(feature = "cuda")]
17    #[allow(unsafe_code)]
18    pub(super) fn backward_nf4_gpu_blocks(
19        &mut self,
20        grad_final_hidden: &[f32],
21        seq_len: usize,
22    ) -> Option<()> {
23        let hidden_size = self.model.config.hidden_size;
24
25        // Upload gradient and run RMSNorm backward in a scope to release borrows
26        // before calling the shared block-loop.
27        {
28            let trainer = self.cuda_trainer.as_ref()?;
29            let training_state = self.gpu_training.as_mut()?;
30            let stream = trainer.stream();
31
32            // PMAT-420: Use trainer.upload (fresh alloc) instead of copy_from_host_at
33            training_state.grad_upload_buf = trainer.upload(grad_final_hidden).ok()?;
34
35            // PMAT-420: Re-allocate grad buffers at seq_len if forward re-sized layer_inputs
36            let expected_len = seq_len * hidden_size;
37            if training_state.grad_buf_a.len() != expected_len {
38                training_state.grad_buf_a = trainer.zeros(expected_len).ok()?;
39                training_state.grad_buf_b = trainer.zeros(expected_len).ok()?;
40                training_state.output_scratch = trainer.zeros(expected_len).ok()?;
41                training_state.grad_upload_buf = trainer.upload(grad_final_hidden).ok()?;
42            }
43
44            // RMSNorm backward on GPU
45            crate::autograd::cuda_backward::rms_norm_backward(
46                &training_state.blocks_output,
47                &training_state.final_norm_weight,
48                &training_state.grad_upload_buf,
49                &mut training_state.grad_buf_a,
50                &mut training_state.grad_final_norm_weight,
51                seq_len as u32,
52                hidden_size as u32,
53                1e-5_f32,
54                stream,
55            )
56            .ok()?;
57        }
58
59        self.backward_nf4_gpu_blocks_loop(seq_len)
60    }
61
62    /// GPU-resident backward: gradient already in grad_hidden_buf from GEMM (KAIZEN-065).
63    ///
64    /// Same as backward_nf4_gpu_blocks but reads gradient directly from
65    /// grad_hidden_buf instead of uploading from CPU. Eliminates:
66    /// - ~5MB D2H download (grad_hidden_buf -> CPU)
67    /// - ~5MB H2D upload (CPU -> grad_upload_buf)
68    /// - 1x stream.synchronize() GPU drain point
69    /// - 1x Vec<f32> heap allocation (~5MB)
70    #[cfg(feature = "cuda")]
71    #[allow(unsafe_code)]
72    pub(super) fn backward_nf4_gpu_blocks_gpu_resident(&mut self, seq_len: usize) -> Option<()> {
73        let hidden_size = self.model.config.hidden_size;
74
75        // KAIZEN-065: grad_hidden_buf already contains the gradient from lm_head backward GEMM.
76        {
77            let trainer = self.cuda_trainer.as_ref()?;
78            let training_state = self.gpu_training.as_mut()?;
79            let stream = trainer.stream();
80
81            crate::autograd::cuda_backward::rms_norm_backward(
82                &training_state.blocks_output,
83                &training_state.final_norm_weight,
84                &training_state.grad_hidden_buf,
85                &mut training_state.grad_buf_a,
86                &mut training_state.grad_final_norm_weight,
87                seq_len as u32,
88                hidden_size as u32,
89                1e-5_f32,
90                stream,
91            )
92            .ok()?;
93        }
94
95        let result = self.backward_nf4_gpu_blocks_loop(seq_len);
96        // entrenar#318: invalidate shared scratch causal mask after backward
97        // to prevent gradient contamination on next forward (backward writes to shared scratch)
98        if let Some(ref mut scratch) = self.shared_scratch {
99            scratch.causal_mask_cached_seq_len = 0;
100        }
101        result
102    }
103
104    /// Shared backward loop for NF4 blocks -- called by both CPU-upload and
105    /// GPU-resident backward paths after RMSNorm backward completes.
106    ///
107    /// PMAT-488: When CUDA_GRAPH=1, captures the entire backward loop into a
108    /// CUDA graph on first call and replays it on subsequent calls. This
109    /// eliminates 84.6% kernel launch overhead (6.5x throughput improvement).
110    ///
111    /// The graph captures: backward_nf4 + fused_clip + optimizer_step for all
112    /// 28 layers. All operations are async GPU kernels with zero host-device
113    /// sync (PMAT-477 fused clip).
114    #[cfg(feature = "cuda")]
115    #[allow(unsafe_code)]
116    fn backward_nf4_gpu_blocks_loop(&mut self, seq_len: usize) -> Option<()> {
117        let trainer = self.cuda_trainer.as_ref()?;
118        let stream = trainer.stream();
119
120        // PMAT-488: Check for backward graph replay
121        {
122            let training_state = self.gpu_training.as_ref()?;
123            if super::super::backward_graph::use_backward_graph() {
124                if let Some(ref state) = training_state.backward_graph_state {
125                    if state.cached_seq_len == seq_len {
126                        // Replay cached backward graph — all 28 layers in one launch
127                        super::super::backward_graph::replay_backward(state, stream)?;
128                        // Still need to increment step counter for LR scheduling
129                        self.nf4_lora_step += 1;
130                        stream.synchronize().ok()?;
131                        return Some(());
132                    }
133                }
134            }
135        }
136
137        // Either graph not enabled, seq_len changed, or first capture needed
138        let use_graph = super::super::backward_graph::use_backward_graph();
139
140        let lr = self.optimizer.lr();
141        let training_state = self.gpu_training.as_mut()?;
142        let blocks = self.cuda_blocks.as_mut()?;
143        let shared_scratch = self.shared_scratch.as_mut()?;
144        let grad_lora = self.cuda_lora_grad_workspace.as_mut()?;
145        let opt_states = self.cuda_lora_optimizer_states.as_mut()?;
146
147        let num_layers = blocks.len();
148        let grad_a_ptr: *mut GpuBuffer<f32> = std::ptr::from_mut(&mut training_state.grad_buf_a);
149        let grad_b_ptr: *mut GpuBuffer<f32> = std::ptr::from_mut(&mut training_state.grad_buf_b);
150        let mut grad_output_is_a = true;
151
152        self.nf4_lora_step += 1;
153        let step = self.nf4_lora_step;
154
155        let output_scratch_ptr: *mut GpuBuffer<f32> =
156            std::ptr::from_mut(&mut training_state.output_scratch);
157
158        // PMAT-488: Begin graph capture if enabled
159        if use_graph {
160            if let Err(e) = stream.begin_capture(trueno_gpu::driver::CaptureMode::ThreadLocal) {
161                eprintln!(
162                    "[CUDA] Backward graph capture begin failed: {e} — falling back to ungraphed"
163                );
164            }
165        }
166
167        for layer_idx in (0..num_layers).rev() {
168            let (grad_output, grad_input) = unsafe {
169                if grad_output_is_a {
170                    (&*grad_a_ptr, &mut *grad_b_ptr)
171                } else {
172                    (&*grad_b_ptr, &mut *grad_a_ptr)
173                }
174            };
175
176            // PMAT-483: Per-layer backward profiling (only when NOT in graph capture)
177            let layer_bwd_start = if !use_graph { Some(std::time::Instant::now()) } else { None };
178
179            blocks[layer_idx]
180                .backward_nf4(
181                    &training_state.layer_inputs[layer_idx],
182                    grad_output,
183                    grad_input,
184                    unsafe { &mut *output_scratch_ptr },
185                    seq_len,
186                    stream,
187                    shared_scratch,
188                    grad_lora,
189                )
190                .ok()?;
191
192            // PMAT-477: Fused clip (zero D2H sync) — graph-capture compatible
193            if let Some(max_norm) = self.config.gradient_clip_norm {
194                if let Some(ref clip_state) = self.lora_fused_clip {
195                    super::super::fused_lora_clip::clip_lora_gradients_fused(
196                        grad_lora, max_norm, clip_state, stream,
197                    );
198                } else if !use_graph {
199                    // Sync fallback only when NOT in graph capture (sync breaks capture)
200                    grad_lora.clip_gradients(max_norm, stream);
201                }
202            }
203
204            blocks[layer_idx]
205                .lora_optimizer_step(
206                    &mut opt_states[layer_idx],
207                    step,
208                    lr,
209                    0.9,
210                    0.999,
211                    1e-8,
212                    0.01,
213                    stream,
214                    grad_lora,
215                )
216                .ok()?;
217
218            // PMAT-483: Record per-layer backward time (skip during graph capture)
219            if let Some(start) = layer_bwd_start {
220                if layer_idx < training_state.profiler_layer_bwd_us.len() {
221                    training_state.profiler_layer_bwd_us[layer_idx] =
222                        start.elapsed().as_micros() as u64;
223                }
224            }
225
226            grad_output_is_a = !grad_output_is_a;
227        }
228
229        // PMAT-488: End graph capture and cache
230        if use_graph {
231            match stream.end_capture() {
232                Ok(graph) => match graph.instantiate() {
233                    Ok(exec) => {
234                        eprintln!(
235                            "[CUDA] Backward graph captured: seq_len={seq_len}, {num_layers} layers"
236                        );
237                        training_state.backward_graph_state =
238                            Some(super::super::backward_graph::BackwardGraphState {
239                                exec,
240                                cached_seq_len: seq_len,
241                            });
242                    }
243                    Err(e) => eprintln!("[CUDA] Backward graph instantiate failed: {e}"),
244                },
245                Err(e) => eprintln!("[CUDA] Backward graph end_capture failed: {e}"),
246            }
247        }
248
249        stream.synchronize().ok()?;
250
251        Some(())
252    }
253}