#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]
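//! Dtype-cast kernels for CUDA tensors: f32 ⇄ bf16 via hand-built PTX and
//! f32 → f16 via trueno_gpu's prebuilt kernel, plus CPU slice helpers built
//! on the `half` crate.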
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::ptx::{
    PtxArithmetic, PtxComparison, PtxControl, PtxKernel, PtxModule, PtxReg, PtxType,
};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::FORWARD_KERNEL_CACHE;
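/// Emits PTX for an elementwise f32 → bf16 cast kernel.
///
/// The element count is a launch-time kernel parameter, so the PTX is
/// size-independent and is compiled once and cached.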
#[cfg(feature = "cuda")]
fn build_cast_f32_to_bf16_ptx() -> String {
    let kernel = PtxKernel::new("cast_f32_to_bf16")
        .param(PtxType::U64, "src_ptr")
        .param(PtxType::U64, "dst_ptr")
        .param(PtxType::U32, "n")
        .build(|ctx| {
            // Global thread index: idx = ctaid.x * ntid.x + tid.x.
            let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
            let ntid_x = ctx.special_reg(PtxReg::NtidX);
            let tid_x = ctx.special_reg(PtxReg::TidX);
            let idx = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
            // Bounds check: excess threads jump straight to the exit label.
            let n_param = ctx.load_param_u32("n");
            let pred = ctx.setp_ge_u32(idx, n_param);
            ctx.branch_if(pred, "exit");
            let src_ptr = ctx.load_param_u64("src_ptr");
            let dst_ptr = ctx.load_param_u64("dst_ptr");
            // Load the f32 bit pattern (4-byte stride).
            let offset = ctx.mul_wide_u32(idx, 4);
            let addr = ctx.add_u64(src_ptr, offset);
            let bits = ctx.ld_global_u32(addr);
            // bf16 is the high half of an f32: a 16-bit right shift truncates
            // the mantissa (no round-to-nearest-even, unlike the CPU helper).
            let bf16_bits = ctx.shr_u32_imm(bits, 16);
            // Store the bf16 bit pattern (2-byte stride).
            let dst_offset = ctx.mul_wide_u32(idx, 2);
            let dst_addr = ctx.add_u64(dst_ptr, dst_offset);
            ctx.st_global_u16(dst_addr, bf16_bits);
            ctx.label("exit");
            ctx.ret();
        });
    // sm_70 is a conservative baseline; the driver JIT recompiles the PTX for
    // the actual device at load time.
    PtxModule::new().target("sm_70").add_kernel(kernel).emit()
}
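/// Emits PTX for an elementwise bf16 → f32 cast kernel (lossless widening).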
#[cfg(feature = "cuda")]
fn build_cast_bf16_to_f32_ptx() -> String {
    let kernel = PtxKernel::new("cast_bf16_to_f32")
        .param(PtxType::U64, "src_ptr")
        .param(PtxType::U64, "dst_ptr")
        .param(PtxType::U32, "n")
        .build(|ctx| {
            // Global thread index: idx = ctaid.x * ntid.x + tid.x.
            let ctaid_x = ctx.special_reg(PtxReg::CtaIdX);
            let ntid_x = ctx.special_reg(PtxReg::NtidX);
            let tid_x = ctx.special_reg(PtxReg::TidX);
            let idx = ctx.mad_lo_u32(ctaid_x, ntid_x, tid_x);
            // Bounds check: excess threads jump straight to the exit label.
            let n_param = ctx.load_param_u32("n");
            let pred = ctx.setp_ge_u32(idx, n_param);
            ctx.branch_if(pred, "exit");
            let src_ptr = ctx.load_param_u64("src_ptr");
            let dst_ptr = ctx.load_param_u64("dst_ptr");
            // Load the bf16 bit pattern (2-byte stride).
            let src_offset = ctx.mul_wide_u32(idx, 2);
            let src_addr = ctx.add_u64(src_ptr, src_offset);
            let bf16_bits = ctx.ld_global_u16(src_addr);
            // Widening is exact: place the bf16 bits in the high half of an
            // f32 and zero-fill the low mantissa bits.
            let f32_bits = ctx.shl_u32_imm(bf16_bits, 16);
            // Store the f32 bit pattern (4-byte stride).
            let dst_offset = ctx.mul_wide_u32(idx, 4);
            let dst_addr = ctx.add_u64(dst_ptr, dst_offset);
            ctx.st_global_u32(dst_addr, f32_bits);
            ctx.label("exit");
            ctx.ret();
        });
    PtxModule::new().target("sm_70").add_kernel(kernel).emit()
}
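/// Casts `n` f32 elements in `src` to bf16 bit patterns in `dst` on `stream`.
///
/// The GPU kernel truncates the mantissa, whereas [`f32_slice_to_bf16`] rounds
/// to nearest-even, so the two paths can differ by one bf16 ULP.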
#[cfg(feature = "cuda")]
pub fn cast_f32_to_bf16_gpu(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<u16>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    // Compile the PTX once; later calls hit the cache.
    let key = "cast_f32_to_bf16";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let ptx = build_cast_f32_to_bf16_ptx();
            cache.get_or_compile(key, &ptx)?
        }
    };
    // One thread per element, 256 threads per block.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();
    // The CUDA launch ABI takes an array of pointers to the argument values.
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];
    // SAFETY: `args` matches the kernel's (u64, u64, u32) parameter list;
    // callers must pass `n` no larger than either buffer's length.
    unsafe {
        stream.launch_kernel(module, "cast_f32_to_bf16", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("cast_f32_to_bf16 launch failed: {e:?}"))
        })?;
    }
    Ok(())
}
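/// Casts `n` bf16 bit patterns in `src` to f32 elements in `dst` on `stream`.
/// This direction is exact: every bf16 value is representable as an f32.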
#[cfg(feature = "cuda")]
pub fn cast_bf16_to_f32_gpu(
    src: &GpuBuffer<u16>,
    dst: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    // Compile the PTX once; later calls hit the cache.
    let key = "cast_bf16_to_f32";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let ptx = build_cast_bf16_to_f32_ptx();
            cache.get_or_compile(key, &ptx)?
        }
    };
    // One thread per element, 256 threads per block.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();
    // The CUDA launch ABI takes an array of pointers to the argument values.
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];
    // SAFETY: `args` matches the kernel's (u64, u64, u32) parameter list;
    // callers must pass `n` no larger than either buffer's length.
    unsafe {
        stream.launch_kernel(module, "cast_bf16_to_f32", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("cast_bf16_to_f32 launch failed: {e:?}"))
        })?;
    }
    Ok(())
}
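/// CPU helper: converts an f32 slice to bf16 (round-to-nearest-even).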
pub fn f32_slice_to_bf16(src: &[f32]) -> Vec<half::bf16> {
    src.iter().map(|&v| half::bf16::from_f32(v)).collect()
}
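/// CPU helper: widens a bf16 slice back to f32 (lossless).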
pub fn bf16_slice_to_f32(src: &[half::bf16]) -> Vec<f32> {
    src.iter().map(|v| v.to_f32()).collect()
}
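/// Casts `n` f32 elements in `src` to f16 bit patterns in `dst` on `stream`.
///
/// f16 has a different exponent width than f32, so the conversion is real
/// rounding work rather than a bit shift; it is delegated to trueno_gpu's
/// `CastF32ToF16Kernel` instead of hand-built PTX.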
#[cfg(feature = "cuda")]
pub fn cast_f32_to_f16_gpu(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<u16>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    use trueno_gpu::kernels::{CastF32ToF16Kernel, Kernel};
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    let key = "cast_f32_to_f16";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            // Unlike the bf16 kernels above, this PTX comes from trueno_gpu's
            // prebuilt kernel set and is emitted for the cache's SM target.
            let kernel = CastF32ToF16Kernel;
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };
    // One thread per element, 256 threads per block.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];
    // SAFETY: `args` matches the kernel's (u64, u64, u32) parameter list;
    // callers must pass `n` no larger than either buffer's length.
    unsafe {
        stream
            .launch_kernel(module, "cast_f32_to_f16", &config, &mut args)
            .map_err(|e| CudaTensorError::KernelError(format!("f32→f16 cast failed: {e:?}")))?;
    }
    Ok(())
}
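// A minimal CPU-side sanity check for the slice helpers; the values in the
// round-trip test are exactly representable in bf16 (8 mantissa bits), so the
// round trip must be bit-exact. The GPU kernels need a device and are not
// exercised here.
#[cfg(test)]
mod tests {
    use super::{bf16_slice_to_f32, f32_slice_to_bf16};

    #[test]
    fn bf16_round_trip_is_exact_for_representable_values() {
        let src = [0.0_f32, 1.0, -2.5, 0.15625, 1024.0];
        let back = bf16_slice_to_f32(&f32_slice_to_bf16(&src));
        assert_eq!(back, src);
    }

    #[test]
    fn bf16_conversion_rounds_to_nearest() {
        // 1 + 3*2^-10 lies past the midpoint between 1.0 and 1 + 2^-8, so
        // round-to-nearest-even rounds up; truncation would give back 1.0.
        let back = bf16_slice_to_f32(&f32_slice_to_bf16(&[1.0_f32 + 3.0 / 1024.0]));
        assert_eq!(back[0], 1.0 + 1.0 / 256.0);
    }
}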