entrenar/autograd/cuda_forward/
bf16_cast.rs1#![allow(unsafe_code)]
16#![allow(trivial_casts)]
17#![allow(clippy::borrow_as_ptr)]
18#![allow(clippy::ref_as_ptr)]
19
20#[cfg(feature = "cuda")]
21use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
22#[cfg(feature = "cuda")]
23use trueno_gpu::ptx::{PtxArithmetic, PtxComparison, PtxControl, PtxKernel, PtxModule, PtxType};
24
25use crate::autograd::cuda_tensor::{CudaTensorError, Result};
26
27#[cfg(feature = "cuda")]
28use super::cache::FORWARD_KERNEL_CACHE;
29
/// Emits the PTX module that contains the `cast_f32_to_bf16` kernel.
///
/// The kernel takes `(src_ptr: u64, dst_ptr: u64, n: u32)`. Each thread derives
/// an element index from the `CtaIdX`, `NtidX` and `TidX` special registers,
/// branches to the exit label when that index is `>= n`, and otherwise loads
/// the 32-bit word at `src_ptr + 4 * index`, shifts it right by 16 bits and
/// stores the result as a 16-bit value at `dst_ptr + 2 * index`.
///
/// The right shift simply drops the low 16 bits of the f32 bit pattern; no
/// rounding step is emitted.
/// NOTE(review): with plain truncation, a NaN whose payload lives only in the
/// low 16 mantissa bits ends up with an infinity bit pattern — confirm this is
/// acceptable for callers.
///
/// `_n` is ignored: the element count is a kernel parameter supplied at launch
/// time. The module is always emitted for the `sm_70` target.
#[cfg(feature = "cuda")]
fn build_cast_f32_to_bf16_ptx(_n: u32) -> String {
    use trueno_gpu::ptx::PtxReg;

    let cast_kernel = PtxKernel::new("cast_f32_to_bf16")
        .param(PtxType::U64, "src_ptr")
        .param(PtxType::U64, "dst_ptr")
        .param(PtxType::U32, "n")
        .build(|ctx| {
            // Per-thread element index built from block id, block size and thread id.
            let block_id = ctx.special_reg(PtxReg::CtaIdX);
            let block_size = ctx.special_reg(PtxReg::NtidX);
            let thread_id = ctx.special_reg(PtxReg::TidX);
            let elem = ctx.mad_lo_u32(block_id, block_size, thread_id);

            // Threads whose index is at or past `n` skip straight to the exit label.
            let count = ctx.load_param_u32("n");
            let out_of_range = ctx.setp_ge_u32(elem, count);
            ctx.branch_if(out_of_range, "exit");

            let src_base = ctx.load_param_u64("src_ptr");
            let dst_base = ctx.load_param_u64("dst_ptr");

            // Source elements are 4 bytes wide.
            let src_byte_offset = ctx.mul_wide_u32(elem, 4);
            let src_addr = ctx.add_u64(src_base, src_byte_offset);
            let f32_word = ctx.ld_global_u32(src_addr);

            // Keep only the upper 16 bits of the 32-bit word.
            let upper_half = ctx.shr_u32_imm(f32_word, 16);

            // Destination elements are 2 bytes wide.
            let dst_byte_offset = ctx.mul_wide_u32(elem, 2);
            let out_addr = ctx.add_u64(dst_base, dst_byte_offset);
            ctx.st_global_u16(out_addr, upper_half);

            ctx.label("exit");
            ctx.ret();
        });

    PtxModule::new()
        .target("sm_70")
        .add_kernel(cast_kernel)
        .emit()
}
71
/// Emits the PTX module that contains the `cast_bf16_to_f32` kernel.
///
/// The kernel takes `(src_ptr: u64, dst_ptr: u64, n: u32)`. Each thread derives
/// an element index from the `CtaIdX`, `NtidX` and `TidX` special registers,
/// branches to the exit label when that index is `>= n`, and otherwise loads
/// the 16-bit value at `src_ptr + 2 * index`, shifts it left by 16 bits and
/// stores the resulting 32-bit word at `dst_ptr + 4 * index`.
///
/// `_n` is ignored: the element count is a kernel parameter supplied at launch
/// time. The module is always emitted for the `sm_70` target.
#[cfg(feature = "cuda")]
fn build_cast_bf16_to_f32_ptx(_n: u32) -> String {
    use trueno_gpu::ptx::PtxReg;

    let cast_kernel = PtxKernel::new("cast_bf16_to_f32")
        .param(PtxType::U64, "src_ptr")
        .param(PtxType::U64, "dst_ptr")
        .param(PtxType::U32, "n")
        .build(|ctx| {
            // Per-thread element index built from block id, block size and thread id.
            let block_id = ctx.special_reg(PtxReg::CtaIdX);
            let block_size = ctx.special_reg(PtxReg::NtidX);
            let thread_id = ctx.special_reg(PtxReg::TidX);
            let elem = ctx.mad_lo_u32(block_id, block_size, thread_id);

            // Threads whose index is at or past `n` skip straight to the exit label.
            let count = ctx.load_param_u32("n");
            let out_of_range = ctx.setp_ge_u32(elem, count);
            ctx.branch_if(out_of_range, "exit");

            let src_base = ctx.load_param_u64("src_ptr");
            let dst_base = ctx.load_param_u64("dst_ptr");

            // Source elements are 2 bytes wide.
            let src_byte_offset = ctx.mul_wide_u32(elem, 2);
            let in_addr = ctx.add_u64(src_base, src_byte_offset);
            let half_word = ctx.ld_global_u16(in_addr);

            // Move the 16 loaded bits into the upper half of a 32-bit word.
            let widened = ctx.shl_u32_imm(half_word, 16);

            // Destination elements are 4 bytes wide.
            let dst_byte_offset = ctx.mul_wide_u32(elem, 4);
            let out_addr = ctx.add_u64(dst_base, dst_byte_offset);
            ctx.st_global_u32(out_addr, widened);

            ctx.label("exit");
            ctx.ret();
        });

    PtxModule::new()
        .target("sm_70")
        .add_kernel(cast_kernel)
        .emit()
}
113
/// Launches the `cast_f32_to_bf16` kernel to convert `n` elements of `src`
/// into 16-bit values written to `dst`.
///
/// The kernel module is looked up in `FORWARD_KERNEL_CACHE` under the key
/// `"cast_f32_to_bf16"` and compiled on first use. The launch uses 256-thread
/// blocks and `ceil(n / 256)` blocks, i.e. one thread per element.
///
/// A call with `n == 0` is a no-op that returns `Ok(())` without touching the
/// kernel cache or the stream.
///
/// NOTE(review): `n` is not validated against the lengths of `src` or `dst`;
/// the caller must make sure both buffers hold at least `n` elements.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if `FORWARD_KERNEL_CACHE` has
///   not been initialised.
/// - [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or
///   the kernel launch fails.
/// - Any error returned while compiling the kernel is propagated.
#[cfg(feature = "cuda")]
pub fn cast_f32_to_bf16_gpu(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<u16>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Nothing to convert. Without this guard the grid below would be
    // `0.div_ceil(256) == 0` blocks, which is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = "cast_f32_to_bf16";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            // The builder ignores its argument; the element count is passed at launch.
            let ptx = build_cast_f32_to_bf16_ptx(0);
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One thread per element, 256 threads per block.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values,
    // in the kernel's declared order: (src_ptr, dst_ptr, n).
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];

    // SAFETY: every entry of `args` points at a local (`src_ptr`, `dst_ptr`, `n`)
    // that stays alive until after `launch_kernel` returns, and the argument
    // order matches the parameters declared for `cast_f32_to_bf16`. The caller
    // is responsible for `src` and `dst` holding at least `n` elements.
    unsafe {
        stream.launch_kernel(module, "cast_f32_to_bf16", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("cast_f32_to_bf16 launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
160
/// Launches the `cast_bf16_to_f32` kernel to convert `n` 16-bit elements of
/// `src` into f32 values written to `dst`.
///
/// The kernel module is looked up in `FORWARD_KERNEL_CACHE` under the key
/// `"cast_bf16_to_f32"` and compiled on first use. The launch uses 256-thread
/// blocks and `ceil(n / 256)` blocks, i.e. one thread per element.
///
/// A call with `n == 0` is a no-op that returns `Ok(())` without touching the
/// kernel cache or the stream.
///
/// NOTE(review): `n` is not validated against the lengths of `src` or `dst`;
/// the caller must make sure both buffers hold at least `n` elements.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if `FORWARD_KERNEL_CACHE` has
///   not been initialised.
/// - [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or
///   the kernel launch fails.
/// - Any error returned while compiling the kernel is propagated.
#[cfg(feature = "cuda")]
pub fn cast_bf16_to_f32_gpu(
    src: &GpuBuffer<u16>,
    dst: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Nothing to convert. Without this guard the grid below would be
    // `0.div_ceil(256) == 0` blocks, which is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = "cast_bf16_to_f32";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            // The builder ignores its argument; the element count is passed at launch.
            let ptx = build_cast_bf16_to_f32_ptx(0);
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One thread per element, 256 threads per block.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values,
    // in the kernel's declared order: (src_ptr, dst_ptr, n).
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];

    // SAFETY: every entry of `args` points at a local (`src_ptr`, `dst_ptr`, `n`)
    // that stays alive until after `launch_kernel` returns, and the argument
    // order matches the parameters declared for `cast_bf16_to_f32`. The caller
    // is responsible for `src` and `dst` holding at least `n` elements.
    unsafe {
        stream.launch_kernel(module, "cast_bf16_to_f32", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("cast_bf16_to_f32 launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
206
207pub fn f32_slice_to_bf16(src: &[f32]) -> Vec<half::bf16> {
211 src.iter().map(|&v| half::bf16::from_f32(v)).collect()
212}
213
214pub fn bf16_slice_to_f32(src: &[half::bf16]) -> Vec<f32> {
216 src.iter().map(|v| v.to_f32()).collect()
217}
218
/// Launches the `cast_f32_to_f16` kernel to convert `n` elements of `src`
/// into 16-bit values written to `dst`.
///
/// The kernel module is looked up in `FORWARD_KERNEL_CACHE` under the key
/// `"cast_f32_to_f16"`. On a cache miss the PTX is emitted by
/// `CastF32ToF16Kernel` for the cache's `sm_target()` and compiled. The launch
/// uses 256-thread blocks and `ceil(n / 256)` blocks.
///
/// A call with `n == 0` is a no-op that returns `Ok(())` without touching the
/// kernel cache or the stream.
///
/// NOTE(review): `n` is not validated against the lengths of `src` or `dst`;
/// the caller must make sure both buffers hold at least `n` elements.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if `FORWARD_KERNEL_CACHE` has
///   not been initialised.
/// - [`CudaTensorError::KernelError`] if the cache lock cannot be acquired or
///   the kernel launch fails.
/// - Any error returned while compiling the kernel is propagated.
#[cfg(feature = "cuda")]
pub fn cast_f32_to_f16_gpu(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<u16>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    use trueno_gpu::kernels::{CastF32ToF16Kernel, Kernel};

    // Nothing to convert. Without this guard the grid below would be
    // `0.div_ceil(256) == 0` blocks, which is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = "cast_f32_to_f16";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = CastF32ToF16Kernel;
            // Emit PTX for the architecture the cache was configured with.
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // 256 threads per block, enough blocks to cover `n` elements.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values.
    // NOTE(review): presumably the kernel's parameters are (src, dst, n) in this
    // order — confirm against `CastF32ToF16Kernel`.
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];

    // SAFETY: every entry of `args` points at a local (`src_ptr`, `dst_ptr`, `n`)
    // that stays alive until after `launch_kernel` returns. The caller is
    // responsible for `src` and `dst` holding at least `n` elements.
    unsafe {
        stream
            .launch_kernel(module, "cast_f32_to_f16", &config, &mut args)
            .map_err(|e| CudaTensorError::KernelError(format!("f32→f16 cast failed: {e:?}")))?;
    }

    Ok(())
}