Skip to main content

entrenar/autograd/cuda_backward/
structured.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::backward::{
10    BatchedRmsNormBackwardKernel, BatchedSoftmaxBackwardKernel, LayerNormBackwardKernel,
11    RmsNormGammaReduceKernel, SoftmaxBackwardKernel,
12};
13#[cfg(feature = "cuda")]
14use trueno_gpu::kernels::BatchedVectorizedRmsNormKernel;
15#[cfg(feature = "cuda")]
16use trueno_gpu::kernels::Kernel;
17
18use super::super::cuda_tensor::{CudaTensorError, Result};
19#[cfg(feature = "cuda")]
20use super::cache::KERNEL_CACHE;
21#[cfg(feature = "cuda")]
22use provable_contracts_macros::requires;
23
/// Softmax backward pass on GPU.
///
/// Computes: grad_input = softmax_output * (grad_output - sum(grad_output * softmax_output))
///
/// The launch uses one block per batch row (`grid.x = batch_size`) with at most one warp
/// (`min(seq_len, 32)` threads) per block. The compiled module is cached under a key that
/// encodes both `batch_size` and `seq_len`.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been initialized.
/// - `CudaTensorError::KernelError` if the kernel-cache lock is poisoned or the launch fails.
/// - Any error returned by the cache's `get_or_compile` is propagated unchanged.
#[cfg(feature = "cuda")]
// Contract: backward-pass-v1 / softmax_backward
#[requires(batch_size > 0 && seq_len > 0)]
pub fn softmax_backward(
    softmax_output: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    batch_size: u32,
    seq_len: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Look up the shape-specialized module, compiling it for this device on first use.
    let cache_key = format!("softmax_backward_{batch_size}_{seq_len}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let kernel = SoftmaxBackwardKernel::new(batch_size, seq_len);
            let ptx = kernel.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // Warp-parallel reduction: one block per row, and rows shorter than a warp
    // get only as many threads as they have elements.
    let threads_per_row = seq_len.min(32);
    let config =
        LaunchConfig { grid: (batch_size, 1, 1), block: (threads_per_row, 1, 1), shared_mem: 0 };

    // Kernel parameters are passed by address, in PTX-signature order.
    let y_ptr = softmax_output.as_ptr();
    let dy_ptr = grad_output.as_ptr();
    let dx_ptr = grad_input.as_ptr();
    let mut params: [*mut std::ffi::c_void; 5] = [
        &y_ptr as *const _ as *mut _,
        &dy_ptr as *const _ as *mut _,
        &dx_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &seq_len as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Every entry of `params` points at a live local
    // (or parameter) that outlives the call. The buffers are GPU allocations supplied by
    // the caller, and the argument order matches the kernel's PTX signature.
    let launched =
        unsafe { stream.launch_kernel(module, "softmax_backward", &config, &mut params) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("Softmax backward launch failed: {e:?}"))
    })?;

    Ok(())
}
82
/// Batched softmax backward pass on GPU (handles row_size > 32)
///
/// Computes: grad_input[r][i] = y[r][i] * (grad_output[r][i] - Σⱼ grad_output[r][j] * y[r][j])
///
/// Uses stride-loop + warp-shuffle reduction (one warp per row, one block per row).
///
/// # Contract (C-BSMAX-BACK-002)
///
/// - **Precondition**: softmax_output contains valid softmax output, all buffers have at least
///   total_rows * row_size elements, row_size > 0, total_rows > 0, KERNEL_CACHE initialized
/// - **Postcondition**: grad_input[r][i] = y[r][i] * (∂L/∂y[r][i] - dot(∂L/∂y[r], y[r]))
/// - **Invariant**: Zero CPU-side data transfers. The contract states the kernel is in-place
///   safe (grad_input may alias grad_output). Note that this safe signature cannot express
///   that aliasing, because `grad_input` is `&mut` while `grad_output` is `&`.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been initialized.
/// - `CudaTensorError::KernelError` if the kernel-cache lock is poisoned or the launch fails.
/// - Any error returned by the cache's `get_or_compile` is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn batched_softmax_backward(
    softmax_output: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    total_rows: u32,
    row_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Contract: dimension-independent-kernels-v1.yaml
    // BatchedSoftmaxBackwardKernel is not yet dimension-independent in trueno: it is
    // constructed from (total_rows, row_size), and its PTX may depend on them. A
    // shape-agnostic key would silently reuse a module compiled for the first shape
    // seen for every later shape. Key by both dimensions, as softmax_backward does, until
    // the kernel reads its dimensions only from the launch arguments. Then this can
    // return to the generic "batched_softmax_backward" key.
    let key = format!("batched_softmax_backward_{total_rows}_{row_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = BatchedSoftmaxBackwardKernel::new(total_rows, row_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One warp (32 threads) per row, one block per row
    let config =
        LaunchConfig { grid: (total_rows, 1, 1), block: (32.min(row_size), 1, 1), shared_mem: 0 };

    let output_ptr = softmax_output.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Kernel parameters are passed by address, in PTX-signature order.
    let mut args: [*mut std::ffi::c_void; 5] = [
        &output_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &total_rows as *const _ as *mut _,
        &row_size as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    unsafe {
        stream.launch_kernel(module, "batched_softmax_backward", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Batched softmax backward launch failed: {e:?}"
                ))
            },
        )?;
    }

    Ok(())
}
152
/// RMSNorm backward pass on GPU.
///
/// Computes both the input gradient `grad_input` and the gamma gradient `grad_gamma`
/// with two launches on `stream`:
///
/// 1. `batched_rms_norm_backward`: one block per row (`grid.x = batch_size`). It writes
///    `grad_input` and stores each row's gamma-gradient contribution in a temporary
///    `[batch_size × hidden_size]` scratch buffer.
/// 2. `rms_norm_gamma_reduce`: sums the per-row partials in fixed row order into
///    `grad_gamma[hidden_size]`.
///
/// Uses a stride-loop kernel that supports arbitrary hidden_size (no warp-only limit).
///
/// # Contract (C-RMSBACK-WRAP-001)
///
/// - **Precondition**: `input` contains the original forward input. `input`,
///   `grad_output` and `grad_input` hold at least `batch_size * hidden_size` elements.
///   `gamma` and `grad_gamma` hold at least `hidden_size` elements.
/// - **Postcondition**: grad_input contains ∂L/∂x per the RMSNorm backward formula;
///   `grad_gamma[i]` contains `Σ_r (∂L/∂y[r][i] · x[r][i] / rms[r])` summed in
///   fixed iteration order over rows (FALSIFY-GPUTRAIN-006).
/// - **Invariant**: Uses batched stride-loop kernel + deterministic per-row partial
///   reduction; no hidden_size upper limit; bit-exactly reproducible across two
///   cuda:0 seed=0 runs (no atomicAdd in the gamma accumulation path).
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been initialized.
/// - `CudaTensorError::KernelError` if the kernel-cache lock is poisoned, the scratch
///   allocation fails, or either launch fails.
/// - Any error returned by the cache's `get_or_compile` is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn rms_norm_backward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    grad_gamma: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    eps: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // FALSIFY-GPUTRAIN-006: allocate per-row partial buffer
    // `[batch_size × hidden_size]` for the deterministic two-stage reduction. Each
    // backward block writes EXCLUSIVELY to `grad_gamma_partial[block_idx]`, then the
    // companion `RmsNormGammaReduceKernel` sums rows in fixed order
    // (`r = 0, 1, …, batch_size - 1`) into the final `grad_gamma[hidden_size]`.
    // No atomicAdd is involved — the result is bit-exact across cuda:0 seed=0 reruns.
    let partial_elem_count = (batch_size as usize) * (hidden_size as usize);
    // Clone the context handle so the allocation below does not hold a borrow of
    // `cache`, which is mutated later by `get_or_compile`.
    let ctx = cache.ctx().clone();
    let grad_gamma_partial: GpuBuffer<f32> =
        GpuBuffer::new(&ctx, partial_elem_count).map_err(|e| {
            CudaTensorError::KernelError(format!(
                "RMSNorm backward: grad_gamma_partial alloc failed ({batch_size}×{hidden_size}): {e:?}"
            ))
        })?;

    // ── Stage 1: per-row partial backward kernel ────────────────────────
    // Contract: dimension-independent-kernels-v1.yaml (FALSIFY-DIM-001)
    let key = "batched_rms_norm_backward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = BatchedRmsNormBackwardKernel::new(batch_size, hidden_size, eps);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One warp (32 threads) per row, one block per row
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (32.min(hidden_size), 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();
    // FALSIFY-GPUTRAIN-006: pass the per-row partial buffer (NOT the final
    // grad_gamma) so the backward kernel writes per-row slots without atomics.
    let grad_gamma_partial_ptr = grad_gamma_partial.as_ptr();

    // Kernel parameters are passed by address, in PTX-signature order.
    let mut args: [*mut std::ffi::c_void; 8] = [
        &input_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &grad_gamma_partial_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
        &eps as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    unsafe {
        stream.launch_kernel(module, "batched_rms_norm_backward", &config, &mut args).map_err(
            |e| CudaTensorError::KernelError(format!("RMSNorm backward launch failed: {e:?}")),
        )?;
    }

    // ── Stage 2: deterministic fixed-order cross-row reduction ──────────
    let reduce_key = "rms_norm_gamma_reduce";
    let reduce_module = match cache.get_cached(reduce_key) {
        Some(m) => m,
        None => {
            let kernel = RmsNormGammaReduceKernel::new(batch_size, hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(reduce_key, &ptx)?
        }
    };

    // One thread per gamma column: ceil(hidden_size / BLOCK_SIZE) blocks of BLOCK_SIZE.
    let reduce_config = LaunchConfig {
        grid: (hidden_size.div_ceil(RmsNormGammaReduceKernel::BLOCK_SIZE), 1, 1),
        block: (RmsNormGammaReduceKernel::BLOCK_SIZE, 1, 1),
        shared_mem: 0,
    };

    let final_grad_gamma_ptr = grad_gamma.as_ptr();

    let mut reduce_args: [*mut std::ffi::c_void; 4] = [
        &grad_gamma_partial_ptr as *const _ as *mut _,
        &final_grad_gamma_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: Same FFI invariants as Stage 1. Both buffers are valid GPU
    // allocations sized batch_size*hidden_size and hidden_size respectively.
    unsafe {
        stream
            .launch_kernel(reduce_module, "rms_norm_gamma_reduce", &reduce_config, &mut reduce_args)
            .map_err(|e| {
                CudaTensorError::KernelError(format!("RMSNorm gamma-reduce launch failed: {e:?}"))
            })?;
    }

    // Release the scratch buffer through GpuBuffer's Drop.
    // NOTE(review): both launches above are queued on `stream` and presumably still
    // in flight. Freeing here is only sound if GpuBuffer's Drop orders the free after
    // pending work, for example a device-synchronizing cuMemFree or a stream-ordered
    // free on `stream`. Confirm in trueno_gpu.
    drop(grad_gamma_partial);
    Ok(())
}
286
/// RMSNorm forward pass on GPU (KAIZEN-066).
///
/// Computes: output = input * rsqrt(mean(input^2) + eps) * gamma, independently for each
/// of the `batch_size` rows of length `hidden_size`.
///
/// Launches `BatchedVectorizedRmsNormKernel` on a `(1, batch_size, 1)` grid. That is one
/// 256-thread block (8 warps) per row, with the row selected by Grid.y.
///
/// `eps` is not a parameter of this wrapper. NOTE(review): the epsilon applied is whatever
/// `BatchedVectorizedRmsNormKernel::new` bakes into the PTX. Confirm that it matches the
/// CPU `forward_batched` path.
///
/// # Contract (C-GPUNORM-001)
///
/// - **Precondition**: input has batch_size * hidden_size elements, gamma has hidden_size elements
/// - **Postcondition**: output contains RMSNorm(input) * gamma
/// - **Invariant**: Same numerical result as CPU norm.forward_batched (within fp32 precision)
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been initialized.
/// - `CudaTensorError::KernelError` if the kernel-cache lock is poisoned or the launch fails.
/// - Any error returned by the cache's `get_or_compile` is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // The kernel receives no runtime dimension arguments, so modules are cached per
    // hidden_size. NOTE(review): batch_size reaches the constructor but not the key. This
    // assumes the emitted PTX does not depend on batch_size; confirm in trueno.
    let cache_key = format!("batched_rmsnorm_fwd_{hidden_size}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let kernel = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size);
            let ptx = kernel.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    const THREADS_PER_BLOCK: u32 = 256;
    let config = LaunchConfig {
        // Rows are spread along Grid.y: one block per row.
        grid: (1, batch_size, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        // Room for 8 f32 values: one partial sum per warp of the 256-thread block.
        shared_mem: 8 * 4,
    };

    // Kernel parameters are passed by address, in PTX-signature order:
    // (u64 input_ptr, u64 output_ptr, u64 gamma_ptr).
    let x_ptr = input.as_ptr();
    let y_ptr = output.as_ptr();
    let g_ptr = gamma.as_ptr();
    let mut params: [*mut std::ffi::c_void; 3] = [
        &x_ptr as *const _ as *mut _,
        &y_ptr as *const _ as *mut _,
        &g_ptr as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. input and output have batch_size * hidden_size
    // elements and gamma has hidden_size elements (precondition). Every entry of `params`
    // points at a live local, and the order matches the PTX signature.
    let launched = unsafe {
        stream.launch_kernel(module, "batched_rmsnorm_vectorized", &config, &mut params)
    };
    launched
        .map_err(|e| CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}")))?;

    Ok(())
}
352
/// LayerNorm backward pass on GPU.
///
/// Launches the `layer_norm_backward` kernel to compute gradients for the input
/// (`grad_input`) and for the affine parameters (`grad_gamma`, `grad_beta`).
///
/// Launch shape: one block per row (`grid.x = batch_size`) with `min(hidden_size, 256)`
/// threads. The compiled module is cached under a key encoding both dimensions.
///
/// NOTE(review): no `eps` is passed, and row statistics are not supplied by the caller.
/// Presumably the kernel recomputes mean/variance with a built-in epsilon. How
/// `grad_gamma` / `grad_beta` are accumulated across rows (atomics vs. overwrite) is
/// decided inside the kernel and is not visible here. Confirm in trueno before relying
/// on determinism or on pre-zeroed outputs.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been initialized.
/// - `CudaTensorError::KernelError` if the kernel-cache lock is poisoned or the launch fails.
/// - Any error returned by the cache's `get_or_compile` is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn layer_norm_backward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    grad_gamma: &mut GpuBuffer<f32>,
    grad_beta: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Look up the shape-specialized module, compiling it for this device on first use.
    let cache_key = format!("layer_norm_backward_{batch_size}_{hidden_size}");
    let module = match kernels.get_cached(&cache_key) {
        Some(cached) => cached,
        None => {
            let kernel = LayerNormBackwardKernel::new(batch_size, hidden_size);
            let ptx = kernel.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // One block per row, capped at 256 threads.
    let threads_per_row = hidden_size.min(256);
    let config =
        LaunchConfig { grid: (batch_size, 1, 1), block: (threads_per_row, 1, 1), shared_mem: 0 };

    // Kernel parameters are passed by address, in PTX-signature order.
    let x_ptr = input.as_ptr();
    let g_ptr = gamma.as_ptr();
    let dy_ptr = grad_output.as_ptr();
    let dx_ptr = grad_input.as_ptr();
    let dg_ptr = grad_gamma.as_ptr();
    let db_ptr = grad_beta.as_ptr();
    let mut params: [*mut std::ffi::c_void; 8] = [
        &x_ptr as *const _ as *mut _,
        &g_ptr as *const _ as *mut _,
        &dy_ptr as *const _ as *mut _,
        &dx_ptr as *const _ as *mut _,
        &dg_ptr as *const _ as *mut _,
        &db_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Every entry of `params` points at a live local
    // (or parameter) that outlives the call. The buffers are GPU allocations supplied by
    // the caller, and the argument order matches the kernel's PTX signature.
    let launched =
        unsafe { stream.launch_kernel(module, "layer_norm_backward", &config, &mut params) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("LayerNorm backward launch failed: {e:?}"))
    })?;

    Ok(())
}