Skip to main content

entrenar/autograd/cuda_forward/
matmul_f16.rs

1//! FP16 cuBLAS GEMM operations for training
2//!
3//! Contract: fp16-cublas-gemm-v1.yaml (PMAT-458)
4//!
5//! FP16 GEMM uses tensor cores on sm_89+ (83 TFLOPS vs 2 TFLOPS SIMD).
6//! All matrices are FP16 (CUDA_R_16F) with FP32 accumulation (CUBLAS_COMPUTE_32F).
7//! Expected: ~2x throughput vs fp32 on memory-BW-bound workloads.
8//!
9//! # Safety
10//!
11//! Backward GEMMs use `gemm_f16()` which internally uses CUBLAS_COMPUTE_32F
12//! (FP32 accumulation). This is safe for transposed backward GEMMs — unlike
13//! TF32 tensor cores which produce NaN at gradient magnitude ~1e5 (ALB-076).
14
15#![allow(unsafe_code)]
16#![allow(trivial_casts)]
17#![allow(clippy::borrow_as_ptr)]
18#![allow(clippy::ref_as_ptr)]
19
20#[cfg(feature = "cuda")]
21use trueno_gpu::driver::{CublasHandle, CudaStream, GemmOp, GpuBuffer};
22
23use crate::autograd::cuda_tensor::{CudaTensorError, Result};
24
25#[cfg(feature = "cuda")]
26use super::cache::FORWARD_KERNEL_CACHE;
27
/// FP16 cuBLAS GEMM forward: C[M,N] = A[M,K] @ B[K,N] using tensor cores
///
/// Contract: fp16-cublas-gemm-v1.yaml C-FP16GEMM-001 (PMAT-458)
/// All matrices are FP16 (CUDA_R_16F). Accumulation in FP32 (CUBLAS_COMPUTE_32F).
/// Tensor cores activated via CUBLAS_GEMM_DEFAULT_TENSOR_OP.
/// Expected: ~2x throughput vs fp32 on memory-BW-bound workloads (RTX 4060L).
///
/// All buffers hold row-major data. cuBLAS is column-major, so a row-major
/// `[R, C]` buffer is seen by cuBLAS as its `[C, R]` transpose. The call
/// therefore evaluates `C^T[N,M] = B^T[N,K] @ A^T[K,M]`, whose column-major
/// result is exactly row-major `C[M,N]` in memory, with no explicit transposes.
///
/// `c` is overwritten (beta = 0), not accumulated into. Buffer lengths are not
/// checked here. `a`, `b` and `c` must hold at least `m*k`, `k*n` and `m*n`
/// elements respectively.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache has
///   not been initialized.
/// - [`CudaTensorError::KernelError`] if a dimension does not fit in `i32` (the
///   cuBLAS integer type), the cache lock is poisoned, no cuBLAS handle is
///   cached, or cuBLAS reports a failure.
///
/// NOTE(review): `stream` is not used. The GEMM runs on whatever stream the
/// cached cuBLAS handle is bound to, so callers must ensure that is the stream
/// that produced `a` and `b`. Confirm against the cache setup.
#[cfg(feature = "cuda")]
pub fn gemm_forward_f16(
    a: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    c: &mut GpuBuffer<u16>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // cuBLAS takes `int` sizes. An unchecked `as` cast would silently wrap
    // values above i32::MAX into negative dimensions.
    let (m, k, n) = match (i32::try_from(m), i32::try_from(k), i32::try_from(n)) {
        (Ok(m), Ok(k), Ok(n)) => (m, k, n),
        _ => {
            return Err(CudaTensorError::KernelError(format!(
                "fp16 GEMM forward dimensions exceed i32::MAX (m={m}, k={k}, n={n})"
            )))
        }
    };
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // The guard is held for the whole cuBLAS call so the shared handle is not
    // used concurrently.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    let cublas = cache.cublas().ok_or_else(|| {
        CudaTensorError::KernelError("cuBLAS handle required for fp16 GEMM".to_string())
    })?;
    let _ = stream; // cuBLAS handle already bound to stream
    // cuBLAS rejects any leading dimension below max(1, rows) with
    // CUBLAS_STATUS_INVALID_VALUE. Clamping to 1 keeps zero-sized problems
    // valid and is a no-op for every non-empty shape.
    cublas
        .gemm_f16(
            GemmOp::NoTrans,
            GemmOp::NoTrans,
            n,
            m,
            k,
            1.0,
            b.as_ptr(),
            n.max(1),
            a.as_ptr(),
            k.max(1),
            0.0,
            c.as_ptr(),
            n.max(1),
        )
        .map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS fp16 GEMM forward failed: {e:?}"))
        })
}
72
/// FP16 cuBLAS backward A: grad_A[M,K] = grad_C[M,N] @ B[K,N]^T (tensor cores)
///
/// Contract: fp16-cublas-gemm-v1.yaml C-FP16GEMM-002 (PMAT-458)
/// Gradient GEMM uses fp16 for memory bandwidth savings. Gradient accumulation
/// should be promoted to fp32 in the caller to prevent underflow.
/// Note: trueno gemm_f16 uses CUBLAS_COMPUTE_32F (fp32 accumulation), which
/// is safe for transposed backward GEMMs (unlike TF32 per ALB-076).
///
/// Row-major buffers are seen by column-major cuBLAS as their transposes. The
/// call therefore evaluates `grad_A^T[K,M] = B[K,N] @ grad_C^T[N,M]`, with
/// `op(B) = Trans` to undo the implicit transpose of `b`. The result is
/// row-major `grad_A[M,K]` in memory.
///
/// `grad_a` is overwritten (beta = 0), not accumulated into. Buffer lengths are
/// not checked here. `grad_output`, `b` and `grad_a` must hold at least `m*n`,
/// `k*n` and `m*k` elements respectively.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if a dimension does not fit in
/// `i32` (the cuBLAS integer type) or cuBLAS reports a failure.
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_a_f16(
    cublas: &CublasHandle,
    grad_output: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    grad_a: &mut GpuBuffer<u16>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Checked narrowing: `as i32` would silently wrap sizes above i32::MAX.
    let (m, k, n) = match (i32::try_from(m), i32::try_from(k), i32::try_from(n)) {
        (Ok(m), Ok(k), Ok(n)) => (m, k, n),
        _ => {
            return Err(CudaTensorError::KernelError(format!(
                "fp16 backward_a dimensions exceed i32::MAX (m={m}, k={k}, n={n})"
            )))
        }
    };
    // Leading dimensions must be >= 1 for cuBLAS even when an extent is 0.
    // The clamp is a no-op for every non-empty shape.
    cublas
        .gemm_f16(
            GemmOp::Trans,
            GemmOp::NoTrans,
            k,
            m,
            n,
            1.0,
            b.as_ptr(),
            n.max(1),
            grad_output.as_ptr(),
            n.max(1),
            0.0,
            grad_a.as_ptr(),
            k.max(1),
        )
        .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS fp16 backward_a failed: {e:?}")))
}
108
/// FP16 cuBLAS backward B: grad_B[K,N] = A[M,K]^T @ grad_C[M,N] (tensor cores)
///
/// Contract: fp16-cublas-gemm-v1.yaml C-FP16GEMM-002 (PMAT-458)
///
/// Row-major buffers are seen by column-major cuBLAS as their transposes. The
/// call therefore evaluates `grad_B^T[N,K] = grad_C^T[N,M] @ A[M,K]`, with
/// `op(A) = Trans` to undo the implicit transpose of `a`. The result is
/// row-major `grad_B[K,N]` in memory.
///
/// `grad_b` is overwritten (beta = 0), not accumulated into. Buffer lengths are
/// not checked here. `a`, `grad_output` and `grad_b` must hold at least `m*k`,
/// `m*n` and `k*n` elements respectively.
///
/// # Errors
///
/// Returns [`CudaTensorError::KernelError`] if a dimension does not fit in
/// `i32` (the cuBLAS integer type) or cuBLAS reports a failure.
#[cfg(feature = "cuda")]
pub(crate) fn cublas_gemm_backward_b_f16(
    cublas: &CublasHandle,
    a: &GpuBuffer<u16>,
    grad_output: &GpuBuffer<u16>,
    grad_b: &mut GpuBuffer<u16>,
    m: u32,
    k: u32,
    n: u32,
) -> Result<()> {
    // Checked narrowing: `as i32` would silently wrap sizes above i32::MAX.
    let (m, k, n) = match (i32::try_from(m), i32::try_from(k), i32::try_from(n)) {
        (Ok(m), Ok(k), Ok(n)) => (m, k, n),
        _ => {
            return Err(CudaTensorError::KernelError(format!(
                "fp16 backward_b dimensions exceed i32::MAX (m={m}, k={k}, n={n})"
            )))
        }
    };
    // Leading dimensions must be >= 1 for cuBLAS even when an extent is 0.
    // The clamp is a no-op for every non-empty shape.
    cublas
        .gemm_f16(
            GemmOp::NoTrans,
            GemmOp::Trans,
            n,
            k,
            m,
            1.0,
            grad_output.as_ptr(),
            n.max(1),
            a.as_ptr(),
            k.max(1),
            0.0,
            grad_b.as_ptr(),
            n.max(1),
        )
        .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS fp16 backward_b failed: {e:?}")))
}
140
/// Mixed-precision backward_a: grad_A(fp32) = grad_C(fp16) @ B(fp16)^T (tensor cores)
///
/// Contract: C-FP16GEMM-002 (PMAT-472)
/// Enables dropping fp32 weights: backward uses fp16 weights with fp32 accumulation.
/// Cast grad_output fp32→fp16 at call site, pass fp16 weights, get fp32 grad_input.
///
/// Shapes: `grad_output` is `[M,N]`, `b` is `[K,N]`, `grad_a` is `[M,K]`, all
/// row-major. Column-major cuBLAS sees them transposed, so the call evaluates
/// `grad_A^T[K,M] = B[K,N] @ grad_C^T[N,M]` with `op(B) = Trans`.
///
/// `grad_a` is overwritten (beta = 0), not accumulated into. Buffer lengths are
/// not checked here. `grad_output`, `b` and `grad_a` must hold at least `m*n`,
/// `k*n` and `m*k` elements respectively.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache has
///   not been initialized.
/// - [`CudaTensorError::KernelError`] if a dimension does not fit in `i32`, the
///   cache lock is poisoned, no cuBLAS handle is cached, or cuBLAS fails.
///
/// NOTE(review): `stream` is not used. The GEMM runs on the stream the cached
/// cuBLAS handle is bound to. Confirm callers always pass that same stream.
#[cfg(feature = "cuda")]
pub fn gemm_f16_to_f32_backward_a(
    grad_output: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Checked narrowing: `as i32` would silently wrap sizes above i32::MAX.
    let (m, k, n) = match (i32::try_from(m), i32::try_from(k), i32::try_from(n)) {
        (Ok(m), Ok(k), Ok(n)) => (m, k, n),
        _ => {
            return Err(CudaTensorError::KernelError(format!(
                "fp16→fp32 backward_a dimensions exceed i32::MAX (m={m}, k={k}, n={n})"
            )))
        }
    };
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // The guard stays alive across the cuBLAS call so the shared handle is
    // never used concurrently.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    let cublas = cache.cublas().ok_or_else(|| {
        CudaTensorError::KernelError("cuBLAS handle required for fp16→fp32 backward".to_string())
    })?;
    let _ = stream; // cuBLAS handle already bound to stream
    // Leading dimensions must be >= 1 for cuBLAS even when an extent is 0.
    // The clamp is a no-op for every non-empty shape.
    cublas
        .gemm_f16_to_f32(
            GemmOp::Trans,
            GemmOp::NoTrans,
            k,
            m,
            n,
            1.0,
            b.as_ptr(),
            n.max(1),
            grad_output.as_ptr(),
            n.max(1),
            0.0,
            grad_a.as_ptr(),
            k.max(1),
        )
        .map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS fp16→fp32 backward_a failed: {e:?}"))
        })
}
184
/// Mixed-precision GEMM: C(fp32) = A(fp16) @ B(fp16) using tensor cores
///
/// Contract: fp16-cublas-gemm-v1.yaml C-FP16GEMM-001 (PMAT-470)
/// A and B are FP16 (weights and activations cast to fp16). C is FP32.
/// Uses CUBLAS_COMPUTE_32F with CUBLAS_GEMM_DEFAULT_TENSOR_OP.
/// This is the "practical FP16 path": cast fp32 activations to fp16,
/// multiply by fp16 weights, produce fp32 output for the rest of the pipeline.
///
/// Shapes: `a` is `[M,K]`, `b` is `[K,N]`, `c` is `[M,N]`, all row-major.
/// Column-major cuBLAS sees them transposed, so the call evaluates
/// `C^T[N,M] = B^T[N,K] @ A^T[K,M]`, which is row-major `C[M,N]` in memory.
///
/// `c` is overwritten (beta = 0), not accumulated into. Buffer lengths are not
/// checked here. `a`, `b` and `c` must hold at least `m*k`, `k*n` and `m*n`
/// elements respectively.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache has
///   not been initialized.
/// - [`CudaTensorError::KernelError`] if a dimension does not fit in `i32`, the
///   cache lock is poisoned, no cuBLAS handle is cached, or cuBLAS fails.
///
/// NOTE(review): `stream` is not used. The GEMM runs on the stream the cached
/// cuBLAS handle is bound to. Confirm callers always pass that same stream.
#[cfg(feature = "cuda")]
pub fn gemm_f16_to_f32_forward(
    a: &GpuBuffer<u16>,
    b: &GpuBuffer<u16>,
    c: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Checked narrowing: `as i32` would silently wrap sizes above i32::MAX.
    let (m, k, n) = match (i32::try_from(m), i32::try_from(k), i32::try_from(n)) {
        (Ok(m), Ok(k), Ok(n)) => (m, k, n),
        _ => {
            return Err(CudaTensorError::KernelError(format!(
                "fp16→fp32 GEMM forward dimensions exceed i32::MAX (m={m}, k={k}, n={n})"
            )))
        }
    };
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    // The guard stays alive across the cuBLAS call so the shared handle is
    // never used concurrently.
    let cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    let cublas = cache.cublas().ok_or_else(|| {
        CudaTensorError::KernelError("cuBLAS handle required for fp16→fp32 GEMM".to_string())
    })?;
    let _ = stream; // cuBLAS handle already bound to stream
    // Leading dimensions must be >= 1 for cuBLAS even when an extent is 0.
    // The clamp is a no-op for every non-empty shape.
    cublas
        .gemm_f16_to_f32(
            GemmOp::NoTrans,
            GemmOp::NoTrans,
            n,
            m,
            k,
            1.0,
            b.as_ptr(),
            n.max(1),
            a.as_ptr(),
            k.max(1),
            0.0,
            c.as_ptr(),
            n.max(1),
        )
        .map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS fp16→fp32 GEMM forward failed: {e:?}"))
        })
}