Skip to main content

entrenar/autograd/cuda_forward/
bf16_cast.rs

//! GPU f32↔bf16 cast kernels (R-002: BF16 mixed precision foundation)
//!
//! Provides element-wise conversion between f32 and bf16 on GPU.
//! BF16 uses the same 8-bit exponent as f32 but only 7 mantissa bits,
//! so conversion is a simple truncation (f32→bf16) or zero-extension (bf16→f32).
//!
//! # Contract (C-BF16CAST-001)
//!
//! - `cast_f32_to_bf16`: output[i] == truncate(input[i]) for all i in [0, n)
//! - `cast_bf16_to_f32`: output[i] == extend(input[i]) for all i in [0, n)
//! - Round-trip: `cast_bf16_to_f32(cast_f32_to_bf16(x))` returns `x` with its low
//!   16 mantissa bits cleared, so values already representable in BF16
//!   (7-bit mantissa precision) round-trip exactly
//! - Inf is preserved through both conversions. NaN is preserved when at least one
//!   of its upper 7 mantissa bits is set (e.g. the default quiet NaN); a NaN whose
//!   payload lies only in the low 16 bits truncates to Inf
14
15#![allow(unsafe_code)]
16#![allow(trivial_casts)]
17#![allow(clippy::borrow_as_ptr)]
18#![allow(clippy::ref_as_ptr)]
19
20#[cfg(feature = "cuda")]
21use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
22#[cfg(feature = "cuda")]
23use trueno_gpu::ptx::{PtxArithmetic, PtxComparison, PtxControl, PtxKernel, PtxModule, PtxType};
24
25use crate::autograd::cuda_tensor::{CudaTensorError, Result};
26
27#[cfg(feature = "cuda")]
28use super::cache::FORWARD_KERNEL_CACHE;
29
/// Emit the PTX source for the `cast_f32_to_bf16` kernel.
///
/// The kernel takes three parameters (`src_ptr`, `dst_ptr`, `n`) and processes one
/// element per thread: it reads the f32 at index `i` as raw 32-bit data, shifts it
/// right by 16 so only the upper half (sign + 8-bit exponent + 7-bit mantissa)
/// remains, and stores that as a 16-bit value at index `i` of the destination.
///
/// Because this is a plain truncation, the low 16 mantissa bits are discarded
/// without rounding; a NaN whose payload lives only in those bits becomes Inf.
///
/// The `_n` argument is unused: the element count reaches the kernel through its
/// `n` parameter at launch time. The module target is fixed to `sm_70`.
#[cfg(feature = "cuda")]
fn build_cast_f32_to_bf16_ptx(_n: u32) -> String {
    use trueno_gpu::ptx::PtxReg;

    let cast_kernel = PtxKernel::new("cast_f32_to_bf16")
        .param(PtxType::U64, "src_ptr")
        .param(PtxType::U64, "dst_ptr")
        .param(PtxType::U32, "n")
        .build(|ctx| {
            // Global element index = block index * block size + thread index.
            let block_id = ctx.special_reg(PtxReg::CtaIdX);
            let block_dim = ctx.special_reg(PtxReg::NtidX);
            let thread_id = ctx.special_reg(PtxReg::TidX);
            let elem = ctx.mad_lo_u32(block_id, block_dim, thread_id);

            // Threads whose index is at or past `n` jump straight to the exit label.
            let count = ctx.load_param_u32("n");
            let out_of_range = ctx.setp_ge_u32(elem, count);
            ctx.branch_if(out_of_range, "exit");

            let src_base = ctx.load_param_u64("src_ptr");
            let dst_base = ctx.load_param_u64("dst_ptr");

            // Source element lives at src_ptr + elem * 4 (f32 is 4 bytes wide).
            let src_byte_offset = ctx.mul_wide_u32(elem, 4);
            let src_addr = ctx.add_u64(src_base, src_byte_offset);
            let f32_bits = ctx.ld_global_u32(src_addr);

            // Keep only the upper 16 bits: that is the bf16 encoding (truncation).
            let bf16_bits = ctx.shr_u32_imm(f32_bits, 16);

            // Destination element lives at dst_ptr + elem * 2 (bf16 is 2 bytes wide).
            let dst_byte_offset = ctx.mul_wide_u32(elem, 2);
            let dst_addr = ctx.add_u64(dst_base, dst_byte_offset);
            ctx.st_global_u16(dst_addr, bf16_bits);

            ctx.label("exit");
            ctx.ret();
        });

    PtxModule::new().target("sm_70").add_kernel(cast_kernel).emit()
}
71
/// Emit the PTX source for the `cast_bf16_to_f32` kernel.
///
/// The kernel takes three parameters (`src_ptr`, `dst_ptr`, `n`) and processes one
/// element per thread: it reads the 16-bit bf16 value at index `i`, shifts it left
/// by 16 so it occupies the upper half of a 32-bit word (the low 16 mantissa bits
/// become zero), and stores that word as the f32 bit pattern at index `i` of the
/// destination.
///
/// The `_n` argument is unused: the element count reaches the kernel through its
/// `n` parameter at launch time. The module target is fixed to `sm_70`.
#[cfg(feature = "cuda")]
fn build_cast_bf16_to_f32_ptx(_n: u32) -> String {
    use trueno_gpu::ptx::PtxReg;

    let cast_kernel = PtxKernel::new("cast_bf16_to_f32")
        .param(PtxType::U64, "src_ptr")
        .param(PtxType::U64, "dst_ptr")
        .param(PtxType::U32, "n")
        .build(|ctx| {
            // Global element index = block index * block size + thread index.
            let block_id = ctx.special_reg(PtxReg::CtaIdX);
            let block_dim = ctx.special_reg(PtxReg::NtidX);
            let thread_id = ctx.special_reg(PtxReg::TidX);
            let elem = ctx.mad_lo_u32(block_id, block_dim, thread_id);

            // Threads whose index is at or past `n` jump straight to the exit label.
            let count = ctx.load_param_u32("n");
            let out_of_range = ctx.setp_ge_u32(elem, count);
            ctx.branch_if(out_of_range, "exit");

            let src_base = ctx.load_param_u64("src_ptr");
            let dst_base = ctx.load_param_u64("dst_ptr");

            // Source element lives at src_ptr + elem * 2 (bf16 is 2 bytes wide).
            let src_byte_offset = ctx.mul_wide_u32(elem, 2);
            let src_addr = ctx.add_u64(src_base, src_byte_offset);
            let half_bits = ctx.ld_global_u16(src_addr);

            // Move the 16 bits into the upper half of the word; the low half is zero.
            let widened_bits = ctx.shl_u32_imm(half_bits, 16);

            // Destination element lives at dst_ptr + elem * 4 (f32 is 4 bytes wide).
            let dst_byte_offset = ctx.mul_wide_u32(elem, 4);
            let dst_addr = ctx.add_u64(dst_base, dst_byte_offset);
            ctx.st_global_u32(dst_addr, widened_bits);

            ctx.label("exit");
            ctx.ret();
        });

    PtxModule::new().target("sm_70").add_kernel(cast_kernel).emit()
}
113
/// Cast an f32 GPU buffer to bf16 on the GPU.
///
/// Looks up the `cast_f32_to_bf16` kernel in the forward kernel cache (building and
/// compiling its PTX on first use) and launches it on `stream` with one thread per
/// element, 256 threads per block. No explicit stream synchronization is performed
/// here.
///
/// # Contract (C-BF16CAST-001)
///
/// - **Precondition**: `src.len() >= n` and `dst.len() >= n`. Violations are rejected
///   with an error before anything is launched.
/// - **Postcondition**: for `i < n`, `dst[i]` contains the bf16 representation of
///   `src[i]` (truncated mantissa). `n == 0` is a no-op that returns `Ok(())`.
/// - **Invariant**: No CPU-side data transfers
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache has not
///   been set up.
/// - [`CudaTensorError::KernelError`] if a buffer holds fewer than `n` elements, the
///   cache lock cannot be acquired, or the kernel launch fails.
/// - Any error returned while compiling the kernel is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn cast_f32_to_bf16_gpu(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<u16>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    const THREADS_PER_BLOCK: u32 = 256;

    // Nothing to convert; a grid of zero blocks is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    // Enforce the documented precondition so the kernel never indexes past either
    // allocation. This is what makes the `unsafe` launch below sound.
    let required = usize::try_from(n).map_err(|_err| {
        CudaTensorError::KernelError(format!("cast_f32_to_bf16: n={n} does not fit in usize"))
    })?;
    if src.len() < required || dst.len() < required {
        return Err(CudaTensorError::KernelError(format!(
            "cast_f32_to_bf16: n={n} exceeds buffer length (src={}, dst={})",
            src.len(),
            dst.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = "cast_f32_to_bf16";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            // The builder ignores its argument; the element count is a launch parameter.
            let ptx = build_cast_f32_to_bf16_ptx(0);
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One thread per element, rounded up to whole blocks; the kernel bounds-checks
    // the surplus threads of the last block against `n`.
    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values,
    // in the same order as the kernel's parameter list (src_ptr, dst_ptr, n).
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];

    // SAFETY: Kernel launch requires FFI. `src` and `dst` are live GPU allocations
    // and were checked above to hold at least `n` elements each, so the kernel reads
    // at most n*4 bytes from `src` and writes at most n*2 bytes to `dst`. The values
    // referenced by `args` outlive the launch call.
    unsafe {
        stream.launch_kernel(module, "cast_f32_to_bf16", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("cast_f32_to_bf16 launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
160
/// Cast a bf16 GPU buffer to f32 on the GPU.
///
/// Looks up the `cast_bf16_to_f32` kernel in the forward kernel cache (building and
/// compiling its PTX on first use) and launches it on `stream` with one thread per
/// element, 256 threads per block. No explicit stream synchronization is performed
/// here.
///
/// # Contract (C-BF16CAST-001)
///
/// - **Precondition**: `src.len() >= n` and `dst.len() >= n`. Violations are rejected
///   with an error before anything is launched.
/// - **Postcondition**: for `i < n`, `dst[i]` contains the f32 representation of
///   `src[i]` (zero-extended mantissa). `n == 0` is a no-op that returns `Ok(())`.
/// - **Invariant**: No CPU-side data transfers
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache has not
///   been set up.
/// - [`CudaTensorError::KernelError`] if a buffer holds fewer than `n` elements, the
///   cache lock cannot be acquired, or the kernel launch fails.
/// - Any error returned while compiling the kernel is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn cast_bf16_to_f32_gpu(
    src: &GpuBuffer<u16>,
    dst: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    const THREADS_PER_BLOCK: u32 = 256;

    // Nothing to convert; a grid of zero blocks is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    // Enforce the documented precondition so the kernel never indexes past either
    // allocation. This is what makes the `unsafe` launch below sound.
    let required = usize::try_from(n).map_err(|_err| {
        CudaTensorError::KernelError(format!("cast_bf16_to_f32: n={n} does not fit in usize"))
    })?;
    if src.len() < required || dst.len() < required {
        return Err(CudaTensorError::KernelError(format!(
            "cast_bf16_to_f32: n={n} exceeds buffer length (src={}, dst={})",
            src.len(),
            dst.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = "cast_bf16_to_f32";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            // The builder ignores its argument; the element count is a launch parameter.
            let ptx = build_cast_bf16_to_f32_ptx(0);
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One thread per element, rounded up to whole blocks; the kernel bounds-checks
    // the surplus threads of the last block against `n`.
    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values,
    // in the same order as the kernel's parameter list (src_ptr, dst_ptr, n).
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];

    // SAFETY: Kernel launch requires FFI. `src` and `dst` are live GPU allocations
    // and were checked above to hold at least `n` elements each, so the kernel reads
    // at most n*2 bytes from `src` and writes at most n*4 bytes to `dst`. The values
    // referenced by `args` outlive the launch call.
    unsafe {
        stream.launch_kernel(module, "cast_bf16_to_f32", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("cast_bf16_to_f32 launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
206
207/// CPU-side f32 to bf16 conversion for a slice (uses `half` crate).
208///
209/// Useful for pre-converting weights before GPU upload.
210pub fn f32_slice_to_bf16(src: &[f32]) -> Vec<half::bf16> {
211    src.iter().map(|&v| half::bf16::from_f32(v)).collect()
212}
213
214/// CPU-side bf16 to f32 conversion for a slice (uses `half` crate).
215pub fn bf16_slice_to_f32(src: &[half::bf16]) -> Vec<f32> {
216    src.iter().map(|v| v.to_f32()).collect()
217}
218
/// GPU f32 → f16 cast using trueno's CastF32ToF16Kernel (PTX cvt.rn.f16.f32).
///
/// Enables FP16 GEMM dispatch by casting fp32 activations to fp16.
/// Overhead: ~0.02ms for 512×1536 elements at 256 GB/s — negligible vs GEMM savings.
///
/// The kernel is fetched from the forward kernel cache under the key
/// `cast_f32_to_f16`; on a cache miss its PTX is emitted for the cache's SM target
/// and compiled. It is launched on `stream` with one thread per element, 256
/// threads per block. No explicit stream synchronization is performed here.
///
/// # Contract
///
/// - **Precondition**: `src.len() >= n` and `dst.len() >= n`. Violations are rejected
///   with an error before anything is launched.
/// - `n == 0` is a no-op that returns `Ok(())`.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache has not
///   been set up.
/// - [`CudaTensorError::KernelError`] if a buffer holds fewer than `n` elements, the
///   cache lock cannot be acquired, or the kernel launch fails.
/// - Any error returned while compiling the kernel is propagated unchanged.
#[cfg(feature = "cuda")]
pub fn cast_f32_to_f16_gpu(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<u16>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    use trueno_gpu::kernels::{CastF32ToF16Kernel, Kernel};

    const THREADS_PER_BLOCK: u32 = 256;

    // Nothing to convert; a grid of zero blocks is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    // Reject element counts that exceed either allocation so the kernel never
    // indexes out of bounds. This is what makes the `unsafe` launch below sound.
    let required = usize::try_from(n).map_err(|_err| {
        CudaTensorError::KernelError(format!("cast_f32_to_f16: n={n} does not fit in usize"))
    })?;
    if src.len() < required || dst.len() < required {
        return Err(CudaTensorError::KernelError(format!(
            "cast_f32_to_f16: n={n} exceeds buffer length (src={}, dst={})",
            src.len(),
            dst.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = "cast_f32_to_f16";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = CastF32ToF16Kernel;
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // One thread per element, rounded up to whole blocks.
    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values
    // (source pointer, destination pointer, element count).
    let mut args: [*mut std::ffi::c_void; 3] =
        [&src_ptr as *const _ as *mut _, &dst_ptr as *const _ as *mut _, &n as *const _ as *mut _];

    // SAFETY: Kernel launch requires FFI. `src` and `dst` are live GPU allocations
    // and were checked above to hold at least `n` elements each (n*4 bytes of f32
    // input, n*2 bytes of 16-bit output). The values referenced by `args` outlive
    // the launch call.
    // NOTE(review): assumes the kernel bounds-checks surplus threads in the last
    // block against `n`; its body is defined in trueno_gpu — confirm.
    unsafe {
        stream
            .launch_kernel(module, "cast_f32_to_f16", &config, &mut args)
            .map_err(|e| CudaTensorError::KernelError(format!("f32→f16 cast failed: {e:?}")))?;
    }

    Ok(())
}