entrenar/finetune/instruct_pipeline/
backward.rs1#[allow(clippy::wildcard_imports)]
5use super::*;
6
7#[cfg(feature = "cuda")]
8use trueno_gpu::driver::GpuBuffer;
9
10impl InstructPipeline {
/// Backward pass entry point for the NF4 GPU path when the gradient of the final hidden
/// state lives in host memory.
///
/// Steps performed here:
/// 1. Validate the requested dimensions (see "Returns").
/// 2. Re-allocate `grad_buf_a`, `grad_buf_b` and `output_scratch` via `trainer.zeros` when
///    `grad_buf_a` does not currently hold exactly `seq_len * hidden_size` elements.
/// 3. Upload `grad_final_hidden` into `grad_upload_buf` (once).
/// 4. Run `rms_norm_backward` for the final norm, writing into `grad_buf_a` and
///    `grad_final_norm_weight`.
/// 5. Delegate the per-layer work to `backward_nf4_gpu_blocks_loop`.
///
/// # Parameters
/// * `grad_final_hidden` - host-side gradient; must contain at least
///   `seq_len * hidden_size` values.
/// * `seq_len` - number of sequence positions passed to the kernels.
///
/// # Returns
/// `Some(())` on success. `None` when:
/// * `seq_len * hidden_size` overflows `usize`, or either dimension does not fit in `u32`
///   (the kernel launch takes `u32` dimensions);
/// * `grad_final_hidden` is shorter than `seq_len * hidden_size`;
/// * `cuda_trainer` or `gpu_training` is not initialised;
/// * any allocation, upload, or kernel call reports an error;
/// * the per-layer loop returns `None`.
#[cfg(feature = "cuda")]
pub(super) fn backward_nf4_gpu_blocks(
    &mut self,
    grad_final_hidden: &[f32],
    seq_len: usize,
) -> Option<()> {
    let hidden_size = self.model.config.hidden_size;

    // Checked arithmetic/conversions: a silently wrapped element count or a truncated
    // `u32` dimension would launch the kernel with the wrong shape.
    let expected_len = seq_len.checked_mul(hidden_size)?;
    let seq_len_u32 = u32::try_from(seq_len).ok()?;
    let hidden_size_u32 = u32::try_from(hidden_size).ok()?;

    // The kernel is asked to process `seq_len * hidden_size` gradient values; refuse a
    // host slice that cannot supply that many before touching the device.
    if grad_final_hidden.len() < expected_len {
        return None;
    }

    {
        let trainer = self.cuda_trainer.as_ref()?;
        let training_state = self.gpu_training.as_mut()?;
        let stream = trainer.stream();

        // Resize the working buffers when the sequence length changed since the last call.
        if training_state.grad_buf_a.len() != expected_len {
            training_state.grad_buf_a = trainer.zeros(expected_len).ok()?;
            training_state.grad_buf_b = trainer.zeros(expected_len).ok()?;
            training_state.output_scratch = trainer.zeros(expected_len).ok()?;
        }

        // Single host-to-device upload (the original uploaded a second time after resizing,
        // producing an identical buffer).
        training_state.grad_upload_buf = trainer.upload(grad_final_hidden).ok()?;

        crate::autograd::cuda_backward::rms_norm_backward(
            &training_state.blocks_output,
            &training_state.final_norm_weight,
            &training_state.grad_upload_buf,
            &mut training_state.grad_buf_a,
            &mut training_state.grad_final_norm_weight,
            seq_len_u32,
            hidden_size_u32,
            // NOTE(review): presumably has to match the epsilon used by the forward norm — confirm.
            1e-5_f32,
            stream,
        )
        .ok()?;
    }

    self.backward_nf4_gpu_blocks_loop(seq_len)
}
61
/// Backward pass entry point for the NF4 GPU path when the gradient of the final hidden
/// state is already on the device in `gpu_training.grad_hidden_buf`.
///
/// Runs `rms_norm_backward` for the final norm (reading `grad_hidden_buf`, writing
/// `grad_buf_a` and `grad_final_norm_weight`), delegates the per-layer work to
/// `backward_nf4_gpu_blocks_loop`, and afterwards resets
/// `shared_scratch.causal_mask_cached_seq_len` to `0` regardless of the loop's outcome.
///
/// NOTE(review): unlike the host-upload entry point, this function does not resize
/// `grad_buf_a` / `grad_buf_b` / `output_scratch`; it assumes they already match
/// `seq_len * hidden_size` — confirm against the caller.
///
/// # Returns
/// `Some(())` on success. `None` when `seq_len` or `hidden_size` does not fit in `u32`,
/// when `cuda_trainer` or `gpu_training` is not initialised, when the norm kernel reports
/// an error, or when the per-layer loop returns `None`.
#[cfg(feature = "cuda")]
pub(super) fn backward_nf4_gpu_blocks_gpu_resident(&mut self, seq_len: usize) -> Option<()> {
    let hidden_size = self.model.config.hidden_size;

    // Checked conversions instead of `as u32`: a truncated dimension would launch the
    // kernel with the wrong shape instead of failing.
    let seq_len_u32 = u32::try_from(seq_len).ok()?;
    let hidden_size_u32 = u32::try_from(hidden_size).ok()?;

    {
        let trainer = self.cuda_trainer.as_ref()?;
        let training_state = self.gpu_training.as_mut()?;
        let stream = trainer.stream();

        crate::autograd::cuda_backward::rms_norm_backward(
            &training_state.blocks_output,
            &training_state.final_norm_weight,
            &training_state.grad_hidden_buf,
            &mut training_state.grad_buf_a,
            &mut training_state.grad_final_norm_weight,
            seq_len_u32,
            hidden_size_u32,
            // NOTE(review): presumably has to match the epsilon used by the forward norm — confirm.
            1e-5_f32,
            stream,
        )
        .ok()?;
    }

    let result = self.backward_nf4_gpu_blocks_loop(seq_len);

    // Reset the cached causal-mask sequence length even when the loop failed.
    // NOTE(review): presumably forces the mask to be rebuilt on next use — confirm.
    if let Some(scratch) = self.shared_scratch.as_mut() {
        scratch.causal_mask_cached_seq_len = 0;
    }

    result
}
103
/// Runs the NF4 backward pass and the LoRA optimizer step for every block, from the last
/// layer down to the first, expecting the incoming gradient in `gpu_training.grad_buf_a`.
///
/// Two execution modes:
/// * **Replay** - when `use_backward_graph()` is true and a cached graph exists for the same
///   `seq_len`, the graph is replayed, `nf4_lora_step` is incremented, the stream is
///   synchronised, and the function returns.
/// * **Eager / capture** - otherwise every layer is processed explicitly. When
///   `use_backward_graph()` is true the work is wrapped in a stream capture and, on success,
///   the instantiated graph is cached in `backward_graph_state` for later replays.
///
/// Per layer: `backward_nf4`, optional gradient clipping (fused variant if
/// `lora_fused_clip` is set, otherwise `clip_gradients` only while not capturing),
/// `lora_optimizer_step`, and - while not capturing - a wall-clock timing written to
/// `profiler_layer_bwd_us`. The two gradient buffers swap roles after each layer.
///
/// Behaviour notes:
/// * If `begin_capture` fails the function really does run ungraphed: non-fused clipping
///   and profiling stay enabled and `end_capture` is not called.
/// * If a layer fails while a capture is open, the capture is closed (partial graph
///   discarded) before `None` is returned, so the stream is not left in capture mode.
///
/// NOTE(review): under CUDA stream-capture semantics work recorded during capture is not
/// executed; confirm whether the freshly captured graph needs one launch in the capturing
/// step. Also confirm that `step` and `lr` are not frozen into the captured graph.
///
/// # Returns
/// `Some(())` on success; `None` when required GPU state is missing or any GPU call fails.
#[cfg(feature = "cuda")]
#[allow(unsafe_code)] // raw-pointer ping-pong over the gradient buffers, see SAFETY notes
fn backward_nf4_gpu_blocks_loop(&mut self, seq_len: usize) -> Option<()> {
    let trainer = self.cuda_trainer.as_ref()?;
    let stream = trainer.stream();

    {
        let training_state = self.gpu_training.as_ref()?;
        if super::super::backward_graph::use_backward_graph() {
            if let Some(state) = &training_state.backward_graph_state {
                if state.cached_seq_len == seq_len {
                    // Fast path: replay the cached graph for this sequence length.
                    super::super::backward_graph::replay_backward(state, stream)?;
                    self.nf4_lora_step += 1;
                    stream.synchronize().ok()?;
                    return Some(());
                }
            }
        }
    }

    let use_graph = super::super::backward_graph::use_backward_graph();

    let lr = self.optimizer.lr();
    let training_state = self.gpu_training.as_mut()?;
    let blocks = self.cuda_blocks.as_mut()?;
    let shared_scratch = self.shared_scratch.as_mut()?;
    let grad_lora = self.cuda_lora_grad_workspace.as_mut()?;
    let opt_states = self.cuda_lora_optimizer_states.as_mut()?;

    let num_layers = blocks.len();

    // Raw pointers to three distinct fields of `training_state`, so the gradient buffers
    // can swap input/output roles between layers.
    let grad_a_ptr: *mut GpuBuffer<f32> = std::ptr::from_mut(&mut training_state.grad_buf_a);
    let grad_b_ptr: *mut GpuBuffer<f32> = std::ptr::from_mut(&mut training_state.grad_buf_b);
    let output_scratch_ptr: *mut GpuBuffer<f32> =
        std::ptr::from_mut(&mut training_state.output_scratch);
    let mut grad_output_is_a = true;

    self.nf4_lora_step += 1;
    let step = self.nf4_lora_step;

    // `capturing` is true only when a capture was requested AND actually started; everything
    // below keys off it so a failed `begin_capture` really falls back to the ungraphed path.
    let capturing = use_graph
        && match stream.begin_capture(trueno_gpu::driver::CaptureMode::ThreadLocal) {
            Ok(_) => true,
            Err(e) => {
                eprintln!(
                    "[CUDA] Backward graph capture begin failed: {e} — falling back to ungraphed"
                );
                false
            }
        };

    // Failures are recorded instead of returned immediately so an open capture can be closed.
    let mut layers_ok = true;

    for layer_idx in (0..num_layers).rev() {
        // SAFETY: `grad_a_ptr` and `grad_b_ptr` point at two different fields of
        // `training_state`, which is exclusively borrowed for the rest of this function.
        // Exactly one shared and one mutable reference are created, to different buffers,
        // and both end with this iteration. These fields are not accessed through
        // `training_state` while the references are live.
        let (grad_output, grad_input) = unsafe {
            if grad_output_is_a {
                (&*grad_a_ptr, &mut *grad_b_ptr)
            } else {
                (&*grad_b_ptr, &mut *grad_a_ptr)
            }
        };
        // SAFETY: `output_scratch_ptr` points at a third, distinct field of the same
        // exclusively borrowed struct; this is the only reference to it in this iteration.
        let output_scratch = unsafe { &mut *output_scratch_ptr };

        // Per-layer timing is only taken when not capturing.
        let layer_bwd_start = (!capturing).then(std::time::Instant::now);

        if blocks[layer_idx]
            .backward_nf4(
                &training_state.layer_inputs[layer_idx],
                grad_output,
                grad_input,
                output_scratch,
                seq_len,
                stream,
                shared_scratch,
                grad_lora,
            )
            .is_err()
        {
            layers_ok = false;
            break;
        }

        if let Some(max_norm) = self.config.gradient_clip_norm {
            if let Some(clip_state) = &self.lora_fused_clip {
                super::super::fused_lora_clip::clip_lora_gradients_fused(
                    grad_lora, max_norm, clip_state, stream,
                );
            } else if !capturing {
                // NOTE(review): the non-fused clip is skipped while capturing, presumably
                // because it cannot be recorded into a graph — confirm.
                grad_lora.clip_gradients(max_norm, stream);
            }
        }

        if blocks[layer_idx]
            .lora_optimizer_step(
                &mut opt_states[layer_idx],
                step,
                lr,
                // NOTE(review): the four literals below look like beta1, beta2, epsilon and
                // weight decay — confirm against `lora_optimizer_step`'s signature.
                0.9,
                0.999,
                1e-8,
                0.01,
                stream,
                grad_lora,
            )
            .is_err()
        {
            layers_ok = false;
            break;
        }

        if let Some(start) = layer_bwd_start {
            if layer_idx < training_state.profiler_layer_bwd_us.len() {
                // Saturate instead of truncating the u128 microsecond count.
                training_state.profiler_layer_bwd_us[layer_idx] =
                    u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
            }
        }

        // This layer's input gradient is the next (lower) layer's output gradient.
        grad_output_is_a = !grad_output_is_a;
    }

    if capturing {
        // Always close the capture so the stream never stays in capture mode; only a
        // fully recorded pass is instantiated and cached.
        match stream.end_capture() {
            Ok(graph) if layers_ok => match graph.instantiate() {
                Ok(exec) => {
                    eprintln!(
                        "[CUDA] Backward graph captured: seq_len={seq_len}, {num_layers} layers"
                    );
                    training_state.backward_graph_state =
                        Some(super::super::backward_graph::BackwardGraphState {
                            exec,
                            cached_seq_len: seq_len,
                        });
                }
                Err(e) => eprintln!("[CUDA] Backward graph instantiate failed: {e}"),
            },
            // A layer failed part-way through: drop the partial graph.
            Ok(_) => {}
            Err(e) => eprintln!("[CUDA] Backward graph end_capture failed: {e}"),
        }
    }

    if !layers_ok {
        return None;
    }

    stream.synchronize().ok()?;

    Some(())
}
253}