entrenar/autograd/cuda_forward/
matmul_f16.rs1#![allow(unsafe_code)]
16#![allow(trivial_casts)]
17#![allow(clippy::borrow_as_ptr)]
18#![allow(clippy::ref_as_ptr)]
19
20#[cfg(feature = "cuda")]
21use trueno_gpu::driver::{CublasHandle, CudaStream, GemmOp, GpuBuffer};
22
23use crate::autograd::cuda_tensor::{CudaTensorError, Result};
24
25#[cfg(feature = "cuda")]
26use super::cache::FORWARD_KERNEL_CACHE;
27
28#[cfg(feature = "cuda")]
35pub fn gemm_forward_f16(
36 a: &GpuBuffer<u16>,
37 b: &GpuBuffer<u16>,
38 c: &mut GpuBuffer<u16>,
39 m: u32,
40 k: u32,
41 n: u32,
42 stream: &CudaStream,
43) -> Result<()> {
44 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
45 let cache = cache.lock().map_err(|_err| {
46 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
47 })?;
48 let cublas = cache.cublas().ok_or_else(|| {
49 CudaTensorError::KernelError("cuBLAS handle required for fp16 GEMM".to_string())
50 })?;
51 let _ = stream; cublas
53 .gemm_f16(
54 GemmOp::NoTrans,
55 GemmOp::NoTrans,
56 n as i32,
57 m as i32,
58 k as i32,
59 1.0,
60 b.as_ptr(),
61 n as i32,
62 a.as_ptr(),
63 k as i32,
64 0.0,
65 c.as_ptr(),
66 n as i32,
67 )
68 .map_err(|e| {
69 CudaTensorError::KernelError(format!("cuBLAS fp16 GEMM forward failed: {e:?}"))
70 })
71}
72
73#[cfg(feature = "cuda")]
81pub(crate) fn cublas_gemm_backward_a_f16(
82 cublas: &CublasHandle,
83 grad_output: &GpuBuffer<u16>,
84 b: &GpuBuffer<u16>,
85 grad_a: &mut GpuBuffer<u16>,
86 m: u32,
87 k: u32,
88 n: u32,
89) -> Result<()> {
90 cublas
91 .gemm_f16(
92 GemmOp::Trans,
93 GemmOp::NoTrans,
94 k as i32,
95 m as i32,
96 n as i32,
97 1.0,
98 b.as_ptr(),
99 n as i32,
100 grad_output.as_ptr(),
101 n as i32,
102 0.0,
103 grad_a.as_ptr(),
104 k as i32,
105 )
106 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS fp16 backward_a failed: {e:?}")))
107}
108
109#[cfg(feature = "cuda")]
113pub(crate) fn cublas_gemm_backward_b_f16(
114 cublas: &CublasHandle,
115 a: &GpuBuffer<u16>,
116 grad_output: &GpuBuffer<u16>,
117 grad_b: &mut GpuBuffer<u16>,
118 m: u32,
119 k: u32,
120 n: u32,
121) -> Result<()> {
122 cublas
123 .gemm_f16(
124 GemmOp::NoTrans,
125 GemmOp::Trans,
126 n as i32,
127 k as i32,
128 m as i32,
129 1.0,
130 grad_output.as_ptr(),
131 n as i32,
132 a.as_ptr(),
133 k as i32,
134 0.0,
135 grad_b.as_ptr(),
136 n as i32,
137 )
138 .map_err(|e| CudaTensorError::KernelError(format!("cuBLAS fp16 backward_b failed: {e:?}")))
139}
140
141#[cfg(feature = "cuda")]
147pub fn gemm_f16_to_f32_backward_a(
148 grad_output: &GpuBuffer<u16>,
149 b: &GpuBuffer<u16>,
150 grad_a: &mut GpuBuffer<f32>,
151 m: u32,
152 k: u32,
153 n: u32,
154 stream: &CudaStream,
155) -> Result<()> {
156 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
157 let cache = cache.lock().map_err(|_err| {
158 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
159 })?;
160 let cublas = cache.cublas().ok_or_else(|| {
161 CudaTensorError::KernelError("cuBLAS handle required for fp16→fp32 backward".to_string())
162 })?;
163 let _ = stream;
164 cublas
165 .gemm_f16_to_f32(
166 GemmOp::Trans,
167 GemmOp::NoTrans,
168 k as i32,
169 m as i32,
170 n as i32,
171 1.0,
172 b.as_ptr(),
173 n as i32,
174 grad_output.as_ptr(),
175 n as i32,
176 0.0,
177 grad_a.as_ptr(),
178 k as i32,
179 )
180 .map_err(|e| {
181 CudaTensorError::KernelError(format!("cuBLAS fp16→fp32 backward_a failed: {e:?}"))
182 })
183}
184
185#[cfg(feature = "cuda")]
193pub fn gemm_f16_to_f32_forward(
194 a: &GpuBuffer<u16>,
195 b: &GpuBuffer<u16>,
196 c: &mut GpuBuffer<f32>,
197 m: u32,
198 k: u32,
199 n: u32,
200 stream: &CudaStream,
201) -> Result<()> {
202 let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
203 let cache = cache.lock().map_err(|_err| {
204 CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
205 })?;
206 let cublas = cache.cublas().ok_or_else(|| {
207 CudaTensorError::KernelError("cuBLAS handle required for fp16→fp32 GEMM".to_string())
208 })?;
209 let _ = stream;
210 cublas
211 .gemm_f16_to_f32(
212 GemmOp::NoTrans,
213 GemmOp::NoTrans,
214 n as i32,
215 m as i32,
216 k as i32,
217 1.0,
218 b.as_ptr(),
219 n as i32,
220 a.as_ptr(),
221 k as i32,
222 0.0,
223 c.as_ptr(),
224 n as i32,
225 )
226 .map_err(|e| {
227 CudaTensorError::KernelError(format!("cuBLAS fp16→fp32 GEMM forward failed: {e:?}"))
228 })
229}