// entrenar/autograd/cuda_backward/structured.rs
//
// The kernel launchers in this module build CUDA argument arrays with
// `&local as *const _ as *mut _` casts and call the driver inside `unsafe` blocks; these
// module-level allows silence the lints those patterns would otherwise trigger.
#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]

#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::backward::{
    BatchedRmsNormBackwardKernel, BatchedSoftmaxBackwardKernel, LayerNormBackwardKernel,
    RmsNormGammaReduceKernel, SoftmaxBackwardKernel,
};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::BatchedVectorizedRmsNormKernel;
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::Kernel;

use super::super::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::KERNEL_CACHE;
#[cfg(feature = "cuda")]
use provable_contracts_macros::requires;

#[cfg(feature = "cuda")]
/// Backward pass of a row-wise softmax, executed by the `softmax_backward` kernel on `stream`.
///
/// The kernel receives the saved forward output (`softmax_output`), the upstream gradient
/// (`grad_output`) and the output buffer `grad_input`. The launch uses one block per batch row
/// (`grid.x = batch_size`) with `min(seq_len, 32)` threads per block.
///
/// The compiled module is cached under a key containing both `batch_size` and `seq_len`, even
/// though both values are also passed to the kernel as runtime arguments.
///
/// NOTE(review): buffer lengths are not checked here; presumably each buffer must hold at least
/// `batch_size * seq_len` elements — confirm against the kernel.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the global kernel cache is not initialised.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error propagated from kernel compilation.
#[requires(batch_size > 0 && seq_len > 0)]
pub fn softmax_backward(
    softmax_output: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    batch_size: u32,
    seq_len: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let cache_key = format!("softmax_backward_{batch_size}_{seq_len}");
    let kernel_module = match cache.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let kernel_def = SoftmaxBackwardKernel::new(batch_size, seq_len);
            let ptx = kernel_def.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx)?
        }
    };

    // One block per row, never more than a single warp of threads per block.
    let launch_cfg = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (seq_len.min(32), 1, 1),
        shared_mem: 0,
    };

    let softmax_ptr = softmax_output.as_ptr();
    let upstream_ptr = grad_output.as_ptr();
    let result_ptr = grad_input.as_ptr();

    // Parameter order: softmax output, grad_output, grad_input, batch_size, seq_len.
    let mut kernel_args: [*mut std::ffi::c_void; 5] = [
        &softmax_ptr as *const _ as *mut _,
        &upstream_ptr as *const _ as *mut _,
        &result_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &seq_len as *const _ as *mut _,
    ];

    // SAFETY: every entry of `kernel_args` points at a local declared above that outlives this
    // call, and the entries are intended to match the kernel's parameter list in order and type.
    // NOTE(review): soundness also depends on the three buffers being large enough for the
    // launched grid, which is not checked here.
    let launch_result = unsafe {
        stream.launch_kernel(kernel_module, "softmax_backward", &launch_cfg, &mut kernel_args)
    };
    launch_result.map_err(|e| {
        CudaTensorError::KernelError(format!("Softmax backward launch failed: {e:?}"))
    })?;

    Ok(())
}

#[cfg(feature = "cuda")]
/// Batched softmax backward, executed by the `batched_softmax_backward` kernel on `stream`.
///
/// The kernel receives the saved forward output, the upstream gradient and `grad_input`, plus
/// `total_rows` and `row_size` as runtime arguments. The launch uses one block per row
/// (`grid.x = total_rows`) with `min(row_size, 32)` threads per block.
///
/// NOTE(review): the cache key is shape-independent, yet the kernel is constructed from
/// `total_rows`/`row_size`. Because both are also passed at launch, the PTX is presumably
/// shape-generic. Confirm this; otherwise a later call with different dimensions would reuse
/// mismatched PTX.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the global kernel cache is not initialised.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error propagated from kernel compilation.
pub fn batched_softmax_backward(
    softmax_output: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    total_rows: u32,
    row_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let cache_key = "batched_softmax_backward";
    let kernel_module = match cache.get_cached(cache_key) {
        Some(cached) => cached,
        None => {
            let kernel_def = BatchedSoftmaxBackwardKernel::new(total_rows, row_size);
            let ptx = kernel_def.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(cache_key, &ptx)?
        }
    };

    // One block per row, capped at a single warp.
    let launch_cfg = LaunchConfig {
        grid: (total_rows, 1, 1),
        block: (row_size.min(32), 1, 1),
        shared_mem: 0,
    };

    let softmax_ptr = softmax_output.as_ptr();
    let upstream_ptr = grad_output.as_ptr();
    let result_ptr = grad_input.as_ptr();

    // Parameter order: softmax output, grad_output, grad_input, total_rows, row_size.
    let mut kernel_args: [*mut std::ffi::c_void; 5] = [
        &softmax_ptr as *const _ as *mut _,
        &upstream_ptr as *const _ as *mut _,
        &result_ptr as *const _ as *mut _,
        &total_rows as *const _ as *mut _,
        &row_size as *const _ as *mut _,
    ];

    // SAFETY: each `kernel_args` entry points at a local declared above that outlives this call,
    // and the entries are intended to match the kernel's parameter list in order and type.
    // NOTE(review): buffer sizes relative to `total_rows * row_size` are not checked here.
    let launch_result = unsafe {
        stream.launch_kernel(
            kernel_module,
            "batched_softmax_backward",
            &launch_cfg,
            &mut kernel_args,
        )
    };
    launch_result.map_err(|e| {
        CudaTensorError::KernelError(format!("Batched softmax backward launch failed: {e:?}"))
    })?;

    Ok(())
}

#[cfg(feature = "cuda")]
/// Computes RMSNorm gradients on `stream` using two kernel launches.
///
/// 1. `batched_rms_norm_backward` uses one block per row (`grid.x = batch_size`, at most 32
///    threads per block). It receives `grad_input` as an output, plus a temporary scratch buffer
///    of `batch_size * hidden_size` f32 values (presumably one partial gamma gradient per row).
/// 2. `rms_norm_gamma_reduce` reads that scratch buffer and writes `grad_gamma`. It presumably
///    sums the per-row partials over the batch — confirm against the kernel.
///
/// No explicit stream synchronisation is performed in this function.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the global kernel cache is not initialised.
/// - [`CudaTensorError::KernelError`] if `batch_size` or `hidden_size` is zero, if the cache lock
///   is poisoned, if the scratch allocation fails, or if either launch fails.
/// - Any error propagated from kernel compilation.
pub fn rms_norm_backward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    grad_gamma: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    eps: f32,
    stream: &CudaStream,
) -> Result<()> {
    // Reject degenerate shapes up front. A zero dimension would otherwise reach a zero-element
    // scratch allocation and zero-sized grid/block launches, surfacing as an opaque driver error.
    if batch_size == 0 || hidden_size == 0 {
        return Err(CudaTensorError::KernelError(format!(
            "RMSNorm backward: batch_size and hidden_size must be non-zero (got {batch_size}×{hidden_size})"
        )));
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Scratch buffer for the gamma partials that the reduce kernel consumes.
    let partial_elem_count = (batch_size as usize) * (hidden_size as usize);
    // Presumably cloned so the allocation does not keep `cache` borrowed across the kernel
    // lookups below.
    let ctx = cache.ctx().clone();
    let grad_gamma_partial: GpuBuffer<f32> =
        GpuBuffer::new(&ctx, partial_elem_count).map_err(|e| {
            CudaTensorError::KernelError(format!(
                "RMSNorm backward: grad_gamma_partial alloc failed ({batch_size}×{hidden_size}): {e:?}"
            ))
        })?;

    // Shape-independent key. batch_size, hidden_size and eps are also passed as runtime kernel
    // arguments. NOTE(review): confirm the PTX does not bake in the values given to `new`.
    let key = "batched_rms_norm_backward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = BatchedRmsNormBackwardKernel::new(batch_size, hidden_size, eps);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One block per row, capped at a single warp.
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (32.min(hidden_size), 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();
    let grad_gamma_partial_ptr = grad_gamma_partial.as_ptr();

    // Parameter order: input, gamma, grad_output, grad_input, grad_gamma_partial,
    // batch_size, hidden_size, eps.
    let mut args: [*mut std::ffi::c_void; 8] = [
        &input_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &grad_gamma_partial_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
        &eps as *const _ as *mut _,
    ];

    // SAFETY: every `args` entry points at a local declared above that outlives this call, and
    // the entries are intended to match the kernel's parameter list in order and type. The
    // scratch buffer was sized to batch_size * hidden_size just above.
    // NOTE(review): the caller-provided buffer sizes are not checked here.
    unsafe {
        stream.launch_kernel(module, "batched_rms_norm_backward", &config, &mut args).map_err(
            |e| CudaTensorError::KernelError(format!("RMSNorm backward launch failed: {e:?}")),
        )?;
    }

    let reduce_key = "rms_norm_gamma_reduce";
    let reduce_module = match cache.get_cached(reduce_key) {
        Some(m) => m,
        None => {
            let kernel = RmsNormGammaReduceKernel::new(batch_size, hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(reduce_key, &ptx)?
        }
    };

    // Enough fixed-size blocks to cover every hidden column.
    let reduce_config = LaunchConfig {
        grid: (hidden_size.div_ceil(RmsNormGammaReduceKernel::BLOCK_SIZE), 1, 1),
        block: (RmsNormGammaReduceKernel::BLOCK_SIZE, 1, 1),
        shared_mem: 0,
    };

    let final_grad_gamma_ptr = grad_gamma.as_ptr();

    // Parameter order: grad_gamma_partial, grad_gamma, batch_size, hidden_size.
    let mut reduce_args: [*mut std::ffi::c_void; 4] = [
        &grad_gamma_partial_ptr as *const _ as *mut _,
        &final_grad_gamma_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: same argument-lifetime reasoning as the first launch. Both launches are issued on
    // the same `stream`, so the reduce is expected to run after the partials are written.
    unsafe {
        stream
            .launch_kernel(reduce_module, "rms_norm_gamma_reduce", &reduce_config, &mut reduce_args)
            .map_err(|e| {
                CudaTensorError::KernelError(format!("RMSNorm gamma-reduce launch failed: {e:?}"))
            })?;
    }

    // The scratch buffer must stay alive until both launches have been enqueued.
    // NOTE(review): if `GpuBuffer`'s Drop frees device memory without ordering against `stream`,
    // the reduce kernel could still be reading it here — confirm the Drop semantics.
    drop(grad_gamma_partial);
    Ok(())
}

#[cfg(feature = "cuda")]
/// RMSNorm forward pass, executed by the `batched_rmsnorm_vectorized` kernel on `stream`.
///
/// Launch geometry is a `(1, batch_size, 1)` grid (one block per row along the y dimension)
/// of 256-thread blocks with 32 bytes of dynamic shared memory. Only three device pointers
/// (`input`, `output`, `gamma`) are passed as kernel arguments. `hidden_size` is part of the
/// cache key, so it is presumably baked into the compiled PTX.
///
/// NOTE(review): no epsilon is passed, so the kernel presumably uses a built-in value — confirm.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the global kernel cache is not initialised.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error propagated from kernel compilation.
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Keyed on hidden_size only; the row count comes from the grid's y dimension at launch.
    let cache_key = format!("batched_rmsnorm_fwd_{hidden_size}");
    let kernel_module = match cache.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let kernel_def = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size);
            let ptx = kernel_def.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx)?
        }
    };

    // 256 threads = 8 warps. The 8 * 4 bytes of shared memory presumably hold one f32 partial
    // per warp for the block reduction — confirm against the kernel.
    let launch_cfg = LaunchConfig {
        grid: (1, batch_size, 1),
        block: (256, 1, 1),
        shared_mem: 8 * 4,
    };

    let src_ptr = input.as_ptr();
    let dst_ptr = output.as_ptr();
    let scale_ptr = gamma.as_ptr();

    // Parameter order: input, output, gamma.
    let mut kernel_args: [*mut std::ffi::c_void; 3] = [
        &src_ptr as *const _ as *mut _,
        &dst_ptr as *const _ as *mut _,
        &scale_ptr as *const _ as *mut _,
    ];

    // SAFETY: each `kernel_args` entry points at a local declared above that outlives this call,
    // and the entries are intended to match the kernel's three pointer parameters in order.
    // NOTE(review): buffer lengths versus batch_size * hidden_size are not checked here.
    let launch_result = unsafe {
        stream.launch_kernel(
            kernel_module,
            "batched_rmsnorm_vectorized",
            &launch_cfg,
            &mut kernel_args,
        )
    };
    launch_result
        .map_err(|e| CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}")))?;

    Ok(())
}

#[cfg(feature = "cuda")]
/// Computes LayerNorm gradients with a single `layer_norm_backward` kernel launch on `stream`.
///
/// The kernel receives `input`, `gamma` and the upstream gradient `grad_output`. It also
/// receives pointers to `grad_input`, `grad_gamma` and `grad_beta` as outputs. Launch geometry
/// is one block per row (`grid.x = batch_size`) with `min(hidden_size, 256)` threads per block.
/// The compiled module is cached per `(batch_size, hidden_size)` pair.
///
/// NOTE(review): no epsilon and no saved forward statistics are passed. The kernel presumably
/// recomputes mean/variance from `input` with a built-in epsilon — confirm. It is also not
/// visible here whether the kernel overwrites or accumulates into `grad_gamma` / `grad_beta`,
/// or what lengths the buffers must have.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the global kernel cache is not initialised.
/// - [`CudaTensorError::KernelError`] if `batch_size` or `hidden_size` is zero, if the cache lock
///   is poisoned, or if the launch fails.
/// - Any error propagated from kernel compilation.
pub fn layer_norm_backward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    grad_gamma: &mut GpuBuffer<f32>,
    grad_beta: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Reject degenerate shapes before touching the cache. The cache key includes both
    // dimensions, so a zero dimension would otherwise compile and cache a kernel, only to hand
    // the driver a zero-sized grid or block.
    if batch_size == 0 || hidden_size == 0 {
        return Err(CudaTensorError::KernelError(format!(
            "LayerNorm backward: batch_size and hidden_size must be non-zero (got {batch_size}×{hidden_size})"
        )));
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("layer_norm_backward_{batch_size}_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = LayerNormBackwardKernel::new(batch_size, hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One block per row, at most 256 threads.
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (256.min(hidden_size), 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();
    let grad_gamma_ptr = grad_gamma.as_ptr();
    let grad_beta_ptr = grad_beta.as_ptr();

    // Parameter order: input, gamma, grad_output, grad_input, grad_gamma, grad_beta,
    // batch_size, hidden_size.
    let mut args: [*mut std::ffi::c_void; 8] = [
        &input_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &grad_gamma_ptr as *const _ as *mut _,
        &grad_beta_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: every `args` entry points at a local declared above that outlives this call, and
    // the entries are intended to match the kernel's parameter list in order and type.
    // NOTE(review): buffer sizes relative to batch_size/hidden_size are not checked here.
    unsafe {
        stream.launch_kernel(module, "layer_norm_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("LayerNorm backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}