Skip to main content

entrenar/autograd/cuda_backward/
gemm.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::backward::{GemmBackwardAKernel, GemmBackwardBKernel};
10#[cfg(feature = "cuda")]
11use trueno_gpu::kernels::Kernel;
12
13use super::super::cuda_tensor::{CudaTensorError, Result};
14#[cfg(feature = "cuda")]
15use super::cache::KERNEL_CACHE;
16
17// cuBLAS backward dispatch (ALB-075)
18#[cfg(feature = "cuda")]
19use crate::autograd::cuda_forward::{cublas_gemm_backward_a, cublas_gemm_backward_b};
20
/// Tile edge length for the backward GEMM PTX kernels (C-TILE-BWD-001).
///
/// Each PTX launch uses `TILE x TILE` thread blocks and `2 * TILE^2 * 4` bytes of
/// shared memory (two f32 staging tiles). With TILE=16 that is 256 threads/block
/// and 2 KiB of shared memory. Safe for all dimensions including LoRA rank=16.
///
/// Must be nonzero and divisible by 4 (the inner loop's unroll factor). These
/// invariants and the CUDA per-block limits are enforced at compile time below.
const BACKWARD_TILE_SIZE: u32 = 16;

// Compile-time guards for the invariants documented on `BACKWARD_TILE_SIZE`.
// A violation fails the build instead of producing a mis-tiled kernel or a
// launch-time error.
const _: () = assert!(
    BACKWARD_TILE_SIZE > 0 && BACKWARD_TILE_SIZE % 4 == 0,
    "BACKWARD_TILE_SIZE must be nonzero and divisible by the 4x unroll factor"
);
// CUDA caps a thread block at 1024 threads.
const _: () = assert!(
    BACKWARD_TILE_SIZE * BACKWARD_TILE_SIZE <= 1024,
    "BACKWARD_TILE_SIZE^2 exceeds the 1024 threads-per-block limit"
);
// Without an opt-in attribute, a block may use at most 48 KiB of shared memory.
const _: () = assert!(
    2 * BACKWARD_TILE_SIZE * BACKWARD_TILE_SIZE * 4 <= 48 * 1024,
    "backward GEMM shared memory exceeds the 48 KiB per-block limit"
);
26
/// Backward GEMM with respect to the left operand (trueno#109: tiled).
///
/// For a forward product `C = A @ B`, writes `grad_A = grad_C @ B^T` into `grad_a`.
///
/// When the kernel cache holds a cuBLAS handle, the work is delegated to
/// `cublas_gemm_backward_a`. Otherwise a tiled PTX kernel (shared-memory tiles,
/// C-TILE-BWD-001, 4x-unrolled inner loop) is compiled once per `(m, k, n)` key and
/// launched on `stream`. The kernel-cache lock is held for the whole call.
///
/// # Errors
///
/// * `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been set.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned or the launch fails.
/// * Errors from the cuBLAS path or from PTX compilation are propagated unchanged.
#[cfg(feature = "cuda")]
pub fn gemm_backward_a(
    grad_output: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Type-erases a reference to a kernel argument into the `void*` form the
    // driver launch API expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        (value as *const T).cast_mut().cast()
    }

    let cache_mutex = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_mutex.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // ALB-075: cuBLAS is the fast path (noted as 6-14x faster than the PTX kernel).
    // ALB-076: backward GEMMs use CUBLAS_DEFAULT_MATH (no tensor cores). Per
    // trueno#170, TF32 tensor-core algorithms produced NaNs on these transposed
    // GEMMs at gradient magnitudes ~1e5. Forward NoTrans/NoTrans GEMMs are unaffected.
    if let Some(handle) = cache.cublas() {
        return cublas_gemm_backward_a(handle, grad_output, b, grad_a, m, k, n);
    }

    // PTX fallback. Building the kernel object is cheap; PTX is only emitted on a
    // cache miss.
    let tile_dim = BACKWARD_TILE_SIZE;
    let kernel = GemmBackwardAKernel::tiled_unrolled(m, n, k, tile_dim);
    let entry_point = kernel.name();

    let cache_key = format!("gemm_backward_a_{m}_{k}_{n}");
    let module = match cache.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx_source = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx_source)?
        }
    };

    // One (tile_dim x tile_dim) block per output tile of grad_a[M, K]:
    // grid.x spans the K columns, grid.y spans the M rows.
    let config = LaunchConfig {
        grid: (k.div_ceil(tile_dim), m.div_ceil(tile_dim), 1),
        block: (tile_dim, tile_dim, 1),
        // Two tile_dim x tile_dim f32 staging tiles.
        shared_mem: 2 * tile_dim * tile_dim * 4,
    };

    let grad_c_dev = grad_output.as_ptr();
    let b_dev = b.as_ptr();
    let grad_a_dev = grad_a.as_ptr();

    // Order is (grad_c_ptr, b_ptr, grad_a_ptr, m, n, k).
    // CRITICAL: must match the param declaration order in GemmBackwardAKernel::build_ptx()
    let mut args: [*mut std::ffi::c_void; 6] = [
        kernel_arg(&grad_c_dev),
        kernel_arg(&b_dev),
        kernel_arg(&grad_a_dev),
        kernel_arg(&m),
        kernel_arg(&n),
        kernel_arg(&k),
    ];

    // SAFETY: the launch crosses FFI. Every entry of `args` points at a local or
    // parameter that outlives this call. The argument order matches the PTX
    // signature noted above. The buffers are assumed (caller's responsibility) to be
    // live GPU allocations sized for (m, k, n).
    let launched = unsafe { stream.launch_kernel(module, entry_point, &config, &mut args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("GEMM backward A launch failed: {e:?}"))
    })?;

    Ok(())
}
103
/// Backward GEMM with respect to the right operand (trueno#109: tiled).
///
/// For a forward product `C = A @ B`, writes `grad_B = A^T @ grad_C` into `grad_b`.
///
/// When the kernel cache holds a cuBLAS handle, the work is delegated to
/// `cublas_gemm_backward_b`. Otherwise a tiled PTX kernel (shared-memory tiles,
/// C-TILE-BWD-002, 4x-unrolled inner loop) is compiled once per `(m, k, n)` key and
/// launched on `stream`. The kernel-cache lock is held for the whole call.
///
/// # Errors
///
/// * `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been set.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned or the launch fails.
/// * Errors from the cuBLAS path or from PTX compilation are propagated unchanged.
#[cfg(feature = "cuda")]
pub fn gemm_backward_b(
    a: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_b: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Type-erases a reference to a kernel argument into the `void*` form the
    // driver launch API expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        (value as *const T).cast_mut().cast()
    }

    let cache_mutex = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_mutex.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // ALB-075: cuBLAS is the fast path (noted as 6-14x faster than the PTX kernel).
    // ALB-076: backward GEMMs use CUBLAS_DEFAULT_MATH (no tensor cores). Per
    // trueno#170, TF32 tensor-core algorithms produced NaNs on these transposed
    // GEMMs at gradient magnitudes ~1e5. Forward NoTrans/NoTrans GEMMs are unaffected.
    if let Some(handle) = cache.cublas() {
        return cublas_gemm_backward_b(handle, a, grad_output, grad_b, m, k, n);
    }

    // PTX fallback. Building the kernel object is cheap; PTX is only emitted on a
    // cache miss.
    let tile_dim = BACKWARD_TILE_SIZE;
    let kernel = GemmBackwardBKernel::tiled_unrolled(m, n, k, tile_dim);
    let entry_point = kernel.name();

    let cache_key = format!("gemm_backward_b_{m}_{k}_{n}");
    let module = match cache.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx_source = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx_source)?
        }
    };

    // One (tile_dim x tile_dim) block per output tile of grad_b[K, N]:
    // grid.x spans the N columns, grid.y spans the K rows.
    let config = LaunchConfig {
        grid: (n.div_ceil(tile_dim), k.div_ceil(tile_dim), 1),
        block: (tile_dim, tile_dim, 1),
        // Two tile_dim x tile_dim f32 staging tiles.
        shared_mem: 2 * tile_dim * tile_dim * 4,
    };

    let a_dev = a.as_ptr();
    let grad_c_dev = grad_output.as_ptr();
    let grad_b_dev = grad_b.as_ptr();

    // Order is (a_ptr, grad_c_ptr, grad_b_ptr, m, n, k).
    // CRITICAL: must match the param declaration order in GemmBackwardBKernel::build_ptx()
    let mut args: [*mut std::ffi::c_void; 6] = [
        kernel_arg(&a_dev),
        kernel_arg(&grad_c_dev),
        kernel_arg(&grad_b_dev),
        kernel_arg(&m),
        kernel_arg(&n),
        kernel_arg(&k),
    ];

    // SAFETY: the launch crosses FFI. Every entry of `args` points at a local or
    // parameter that outlives this call. The argument order matches the PTX
    // signature noted above. The buffers are assumed (caller's responsibility) to be
    // live GPU allocations sized for (m, k, n).
    let launched = unsafe { stream.launch_kernel(module, entry_point, &config, &mut args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("GEMM backward B launch failed: {e:?}"))
    })?;

    Ok(())
}
180
/// Accumulating backward GEMM for the left operand (PMAT-484): `grad_A += grad_C @ B^T`.
///
/// Unlike the overwriting backward-A GEMM, the product is added into `grad_a`. The
/// fused Gate+Up backward uses this to skip a separate `cuda_add_inplace` launch.
///
/// Only a cuBLAS implementation exists (beta = 1.0); there is no PTX fallback.
/// `_stream` is accepted for signature parity but not used.
///
/// # Errors
///
/// * `CudaTensorError::DeviceNotInitialized` if `KERNEL_CACHE` has not been set.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned or no cuBLAS
///   handle is available.
/// * Errors from `cublas_gemm_backward_a_accumulate` are propagated unchanged.
#[cfg(feature = "cuda")]
pub fn gemm_backward_a_accumulate(
    grad_output: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    _stream: &CudaStream,
) -> Result<()> {
    let cache_mutex = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let guard = cache_mutex.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Without cuBLAS there is no accumulate path. Per the PMAT-484 notes, NF4 QLoRA
    // training always initializes cuBLAS, so this branch is not expected there.
    let Some(handle) = guard.cublas() else {
        return Err(CudaTensorError::KernelError(
            "gemm_backward_a_accumulate requires cuBLAS (NF4 training always has it)".to_string(),
        ));
    };

    crate::autograd::cuda_forward::cublas_gemm_backward_a_accumulate(
        handle,
        grad_output,
        b,
        grad_a,
        m,
        k,
        n,
    )
}
219
/// FP16-aware backward A with accumulation (PMAT-484): `grad_A += grad_C @ W^T`.
///
/// Accumulating counterpart of the fp16-aware backward-A dispatch. It is used by the
/// fused Gate+Up backward to avoid a separate add pass when possible.
///
/// * `w_fp16 = Some(_)`: the fp16 GEMM path only overwrites its output, so the
///   product is computed into a temporary `m * k` f32 buffer and then added into
///   `grad_a` with `cuda_add_inplace`. Per the original note, cuBLAS fp16 does not
///   easily support beta=1 with mixed precision.
/// * `w_fp16 = None`: delegates to `gemm_backward_a_accumulate`, which accumulates
///   directly with cuBLAS (beta = 1.0) using `w_fp32`.
///
/// # Errors
///
/// * `CudaTensorError::AllocationFailed` if the temporary buffer cannot be allocated.
/// * Errors from the GEMM dispatch, `cuda_add_inplace`, or
///   `gemm_backward_a_accumulate` are propagated unchanged.
#[cfg(feature = "cuda")]
pub fn gemm_backward_a_fp16_dispatch_accumulate(
    grad_output: &GpuBuffer<f32>,
    w_fp16: Option<&GpuBuffer<u16>>,
    w_fp32: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
    ctx: &trueno_gpu::driver::CudaContext,
) -> Result<()> {
    let Some(w16) = w_fp16 else {
        // FP32 weights: cuBLAS accumulates in place with beta = 1.0.
        return gemm_backward_a_accumulate(grad_output, w_fp32, grad_a, m, k, n, stream);
    };

    // FP16 weights: compute into a temp, then accumulate.
    // Widen before multiplying: `m * k` in u32 can wrap for large shapes, which
    // would mis-size both the temp buffer and the element count of the add.
    let grad_a_elems = (m as usize) * (k as usize);
    let mut temp = GpuBuffer::<f32>::new(ctx, grad_a_elems)
        .map_err(|e| CudaTensorError::AllocationFailed(format!("fp16 accum temp: {e:?}")))?;
    gemm_backward_a_fp16_dispatch(
        grad_output,
        Some(w16),
        w_fp32,
        &mut temp,
        m,
        k,
        n,
        stream,
        ctx,
    )?;
    // NOTE(review): `temp` is dropped when this function returns, while the add may
    // still be queued on `stream`. This assumes GpuBuffer's free is synchronizing or
    // stream-ordered; confirm.
    crate::transformer::cuda_block::cuda_add_inplace(grad_a, &temp, grad_a_elems, stream)?;
    Ok(())
}
259
/// FP16-aware backward A dispatch (PMAT-472): writes `grad_A = grad_C @ W^T`,
/// using fp16 weights when available.
///
/// * `w_fp16 = Some(w16)`: casts `grad_output` (`m * n` elements) to fp16 into a
///   scratch buffer, then runs the tensor-core GEMM `gemm_f16_to_f32_backward_a`
///   (fp16×fp16→fp32) into `grad_a`. `w_fp32` is not read on this path.
/// * `w_fp16 = None`: falls back to the fp32 `gemm_backward_a` with `w_fp32`.
///
/// The fp16 path lets callers drop fp32 weight storage. The original note cites
/// ~2.6 GB of VRAM freed for GPU embeddings on yoga 8GB.
///
/// # Errors
///
/// * `CudaTensorError::AllocationFailed` if `m * n` does not fit in `u32` or the fp16
///   scratch buffer cannot be allocated.
/// * Errors from the cast, the fp16 GEMM, or `gemm_backward_a` are propagated unchanged.
#[cfg(feature = "cuda")]
pub fn gemm_backward_a_fp16_dispatch(
    grad_output: &GpuBuffer<f32>,
    w_fp16: Option<&GpuBuffer<u16>>,
    w_fp32: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
    ctx: &trueno_gpu::driver::CudaContext,
) -> Result<()> {
    let Some(w16) = w_fp16 else {
        return gemm_backward_a(grad_output, w_fp32, grad_a, m, k, n, stream);
    };

    // The element count is handed to the cast kernel as u32. A wrapped product would
    // silently under-allocate the scratch buffer and convert only part of the
    // gradient, so reject it explicitly.
    let grad_elems = m.checked_mul(n).ok_or_else(|| {
        CudaTensorError::AllocationFailed(format!(
            "grad f16 cast: {m} x {n} elements overflows u32"
        ))
    })?;
    let mut grad_f16 = GpuBuffer::<u16>::new(ctx, grad_elems as usize)
        .map_err(|e| CudaTensorError::AllocationFailed(format!("grad f16 cast: {e:?}")))?;
    crate::autograd::cuda_forward::cast_f32_to_f16_gpu(
        grad_output,
        &mut grad_f16,
        grad_elems,
        stream,
    )?;
    // NOTE(review): `grad_f16` is dropped on return while the GEMM may still be
    // queued on `stream`. This assumes GpuBuffer's free is synchronizing or
    // stream-ordered; confirm.
    crate::autograd::cuda_forward::gemm_f16_to_f32_backward_a(
        &grad_f16, w16, grad_a, m, k, n, stream,
    )
}