entrenar/autograd/cuda_backward/
gemm.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::backward::{GemmBackwardAKernel, GemmBackwardBKernel};
10#[cfg(feature = "cuda")]
11use trueno_gpu::kernels::Kernel;
12
13use super::super::cuda_tensor::{CudaTensorError, Result};
14#[cfg(feature = "cuda")]
15use super::cache::KERNEL_CACHE;
16
17#[cfg(feature = "cuda")]
19use crate::autograd::cuda_forward::{cublas_gemm_backward_a, cublas_gemm_backward_b};
20
/// Edge length of the square `tile × tile` thread block used by the hand-written
/// (non-cuBLAS) GEMM backward kernels in this module; the same value parameterises the
/// kernel (`tiled_unrolled(.., tile)`), the launch grid, and the shared-memory request.
const BACKWARD_TILE_SIZE: u32 = 16;
26
#[cfg(feature = "cuda")]
/// Backward pass of a GEMM with respect to its left operand, written into `grad_a`.
///
/// If the global kernel cache holds a cuBLAS handle, the work is delegated to
/// `cublas_gemm_backward_a`. Otherwise a tiled `GemmBackwardAKernel` is JIT-compiled
/// (PTX cached under a key derived from `m`, `k`, `n`) and launched on `stream`.
///
/// `grad_a` is `m × k` (the fallback grid tiles `k` along x and `m` along y).
/// NOTE(review): `grad_output` / `b` are presumably `m × n` / `k × n`; buffer lengths
/// are not validated here.
///
/// # Errors
/// - `DeviceNotInitialized` if the kernel cache has not been created.
/// - `KernelError` if the cache lock is poisoned or the kernel launch fails.
/// - Any error from cuBLAS or from PTX compilation is propagated unchanged.
pub fn gemm_backward_a(
    grad_output: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // The cuBLAS call does not take `stream`; presumably the handle is already bound
    // to the right stream.
    if let Some(cublas) = cache.cublas() {
        return cublas_gemm_backward_a(cublas, grad_output, b, grad_a, m, k, n);
    }

    // `grad_a` is m x k. An empty output has nothing to compute, and a zero grid
    // dimension would make the launch below fail outright.
    if m == 0 || k == 0 {
        return Ok(());
    }

    let tile = BACKWARD_TILE_SIZE;
    let kernel = GemmBackwardAKernel::tiled_unrolled(m, n, k, tile);
    let kernel_name = kernel.name();

    // Only emit PTX on a cache miss; emission is skipped for already-compiled shapes.
    let key = format!("gemm_backward_a_{m}_{k}_{n}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Two tile x tile f32 tiles of shared memory (4 bytes per f32).
    let smem = 2 * tile * tile * 4;
    let config = LaunchConfig {
        grid: (k.div_ceil(tile), m.div_ceil(tile), 1),
        block: (tile, tile, 1),
        shared_mem: smem,
    };

    let grad_out_ptr = grad_output.as_ptr();
    let b_ptr = b.as_ptr();
    let grad_a_ptr = grad_a.as_ptr();

    // Kernel parameters are passed as pointers to host-side values; this order must
    // match the generated kernel's parameter list: (grad_output, b, grad_a, m, n, k).
    let mut args: [*mut std::ffi::c_void; 6] = [
        &grad_out_ptr as *const _ as *mut _,
        &b_ptr as *const _ as *mut _,
        &grad_a_ptr as *const _ as *mut _,
        &m as *const _ as *mut _,
        &n as *const _ as *mut _,
        &k as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a local that is alive for the whole
    // `launch_kernel` call, and the entries are ordered as the kernel expects (assumed
    // from the kernel builder). The caller must keep the device buffers alive until
    // the stream work completes and size them for the m/k/n passed in.
    unsafe {
        stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("GEMM backward A launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
103
#[cfg(feature = "cuda")]
/// Backward pass of a GEMM with respect to its right operand, written into `grad_b`.
///
/// If the global kernel cache holds a cuBLAS handle, the work is delegated to
/// `cublas_gemm_backward_b`. Otherwise a tiled `GemmBackwardBKernel` is JIT-compiled
/// (PTX cached under a key derived from `m`, `k`, `n`) and launched on `stream`.
///
/// `grad_b` is `k × n` (the fallback grid tiles `n` along x and `k` along y).
/// NOTE(review): `a` / `grad_output` are presumably `m × k` / `m × n`; buffer lengths
/// are not validated here.
///
/// # Errors
/// - `DeviceNotInitialized` if the kernel cache has not been created.
/// - `KernelError` if the cache lock is poisoned or the kernel launch fails.
/// - Any error from cuBLAS or from PTX compilation is propagated unchanged.
pub fn gemm_backward_b(
    a: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_b: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // The cuBLAS call does not take `stream`; presumably the handle is already bound
    // to the right stream.
    if let Some(cublas) = cache.cublas() {
        return cublas_gemm_backward_b(cublas, a, grad_output, grad_b, m, k, n);
    }

    // `grad_b` is k x n. An empty output has nothing to compute, and a zero grid
    // dimension would make the launch below fail outright.
    if k == 0 || n == 0 {
        return Ok(());
    }

    let tile = BACKWARD_TILE_SIZE;
    let kernel = GemmBackwardBKernel::tiled_unrolled(m, n, k, tile);
    let kernel_name = kernel.name();

    // Only emit PTX on a cache miss; emission is skipped for already-compiled shapes.
    let key = format!("gemm_backward_b_{m}_{k}_{n}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Two tile x tile f32 tiles of shared memory (4 bytes per f32).
    let smem = 2 * tile * tile * 4;
    let config = LaunchConfig {
        grid: (n.div_ceil(tile), k.div_ceil(tile), 1),
        block: (tile, tile, 1),
        shared_mem: smem,
    };

    let a_ptr = a.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_b_ptr = grad_b.as_ptr();

    // Kernel parameters are passed as pointers to host-side values; this order must
    // match the generated kernel's parameter list: (a, grad_output, grad_b, m, n, k).
    let mut args: [*mut std::ffi::c_void; 6] = [
        &a_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_b_ptr as *const _ as *mut _,
        &m as *const _ as *mut _,
        &n as *const _ as *mut _,
        &k as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a local that is alive for the whole
    // `launch_kernel` call, and the entries are ordered as the kernel expects (assumed
    // from the kernel builder). The caller must keep the device buffers alive until
    // the stream work completes and size them for the m/k/n passed in.
    unsafe {
        stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("GEMM backward B launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
180
#[cfg(feature = "cuda")]
/// Left-operand GEMM gradient, accumulated into `grad_a` through cuBLAS only.
///
/// There is no hand-written kernel fallback: without a cuBLAS handle in the kernel
/// cache this returns an error. `_stream` is accepted for signature parity with the
/// other backward entry points but is not used.
///
/// # Errors
/// - `DeviceNotInitialized` if the kernel cache has not been created.
/// - `KernelError` if the cache lock is poisoned or no cuBLAS handle is available.
/// - Any error from the cuBLAS call is propagated unchanged.
pub fn gemm_backward_a_accumulate(
    grad_output: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    _stream: &CudaStream,
) -> Result<()> {
    let guard = KERNEL_CACHE
        .get()
        .ok_or(CudaTensorError::DeviceNotInitialized)?
        .lock()
        .map_err(|_err| {
            CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
        })?;

    let Some(handle) = guard.cublas() else {
        return Err(CudaTensorError::KernelError(
            "gemm_backward_a_accumulate requires cuBLAS (NF4 training always has it)".to_string(),
        ));
    };

    crate::autograd::cuda_forward::cublas_gemm_backward_a_accumulate(
        handle,
        grad_output,
        b,
        grad_a,
        m,
        k,
        n,
    )
}
219
#[cfg(feature = "cuda")]
/// Left-operand GEMM gradient added into `grad_a`, using FP16 weights when available.
///
/// - With `w_fp16 = Some(..)`: the gradient is computed into a freshly allocated
///   `m × k` f32 temporary via [`gemm_backward_a_fp16_dispatch`], then added into
///   `grad_a` with `cuda_add_inplace` on `stream`.
/// - With `w_fp16 = None`: delegates to [`gemm_backward_a_accumulate`] with `w_fp32`.
///
/// # Errors
/// - `AllocationFailed` if the temporary buffer cannot be allocated.
/// - Errors from the GEMM dispatch, the in-place add, or the cuBLAS path are
///   propagated unchanged.
pub fn gemm_backward_a_fp16_dispatch_accumulate(
    grad_output: &GpuBuffer<f32>,
    w_fp16: Option<&GpuBuffer<u16>>,
    w_fp32: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
    ctx: &trueno_gpu::driver::CudaContext,
) -> Result<()> {
    if w_fp16.is_none() {
        return gemm_backward_a_accumulate(grad_output, w_fp32, grad_a, m, k, n, stream);
    }

    // Widen before multiplying: `m * k` in u32 wraps silently in release builds for
    // large shapes, which would under-allocate `temp` while the add below still
    // touches the full gradient.
    let elems = (m as usize) * (k as usize);
    if elems == 0 {
        // Accumulating into an empty gradient is a no-op, and a zero-element device
        // allocation may be rejected by the driver.
        return Ok(());
    }

    let mut temp = GpuBuffer::<f32>::new(ctx, elems)
        .map_err(|e| CudaTensorError::AllocationFailed(format!("fp16 accum temp: {e:?}")))?;
    gemm_backward_a_fp16_dispatch(
        grad_output,
        w_fp16,
        w_fp32,
        &mut temp,
        m,
        k,
        n,
        stream,
        ctx,
    )?;
    crate::transformer::cuda_block::cuda_add_inplace(grad_a, &temp, elems, stream)?;
    Ok(())
}
259
#[cfg(feature = "cuda")]
/// Left-operand GEMM gradient written into `grad_a`, using FP16 weights when available.
///
/// - With `w_fp16 = Some(w16)`: `grad_output` (`m * n` elements) is cast to f16 into
///   a staging buffer, then `gemm_f16_to_f32_backward_a` writes the f32 result.
/// - With `w_fp16 = None`: delegates to [`gemm_backward_a`] with `w_fp32`.
///
/// # Errors
/// - `KernelError` if `m * n` does not fit in the cast kernel's u32 element count.
/// - `AllocationFailed` if the f16 staging buffer cannot be allocated.
/// - Errors from the cast, the GEMM, or the fallback path are propagated unchanged.
pub fn gemm_backward_a_fp16_dispatch(
    grad_output: &GpuBuffer<f32>,
    w_fp16: Option<&GpuBuffer<u16>>,
    w_fp32: &GpuBuffer<f32>,
    grad_a: &mut GpuBuffer<f32>,
    m: u32,
    k: u32,
    n: u32,
    stream: &CudaStream,
    ctx: &trueno_gpu::driver::CudaContext,
) -> Result<()> {
    let Some(w16) = w_fp16 else {
        return gemm_backward_a(grad_output, w_fp32, grad_a, m, k, n, stream);
    };

    // `grad_a` is m x k: an empty output needs no work (and would otherwise trigger a
    // pointless, possibly zero-sized, staging allocation).
    if m == 0 || k == 0 {
        return Ok(());
    }

    // The cast kernel takes a u32 element count; reject shapes whose product does not
    // fit instead of letting `m * n` wrap and under-allocate / under-cast.
    let count = m.checked_mul(n).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "gemm_backward_a_fp16_dispatch: grad_output element count {m}x{n} overflows u32"
        ))
    })?;

    let mut grad_f16 = GpuBuffer::<u16>::new(ctx, count as usize)
        .map_err(|e| CudaTensorError::AllocationFailed(format!("grad f16 cast: {e:?}")))?;
    crate::autograd::cuda_forward::cast_f32_to_f16_gpu(grad_output, &mut grad_f16, count, stream)?;
    crate::autograd::cuda_forward::gemm_f16_to_f32_backward_a(
        &grad_f16, w16, grad_a, m, k, n, stream,
    )
}