mod attention;
mod gemm;
use super::super::*;
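/// Threshold used to decide when GEMM dispatch switches to cuBLAS.
/// Overridable at runtime via the `CUBLAS_GEMM_THRESHOLD` environment
/// variable; defaults to 4 when unset or unparsable.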
pub(crate) fn cublas_gemm_threshold() -> u32 {
std::env::var("CUBLAS_GEMM_THRESHOLD")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(4)
}
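/// PTX kernel: elementwise FP32 -> FP8 E4M3 conversion, one element per
/// thread. Rounds the 3-bit mantissa with round-to-nearest-even, flushes
/// zeros/denormals and underflow to 0x00, and saturates NaN/Inf and
/// overflow to the maximum finite value (sign | 0x7E = +/-448.0).
///
/// Worked example: 1.0f32 = 0x3F800000 has biased exponent 127, so the
/// E4M3 exponent is 127 - 120 = 7 and the packed byte is 7 << 3 = 0x38,
/// which decodes back to 1.0.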
const F32_TO_E4M3_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry f32_to_e4m3(
.param .u64 param_dst,
.param .u64 param_src,
.param .u32 param_count
) {
.reg .u64 %rd<5>;
.reg .u32 %r<16>;
.reg .f32 %f<4>;
.reg .pred %p<4>;
ld.param.u64 %rd0, [param_dst];
ld.param.u64 %rd1, [param_src];
ld.param.u32 %r0, [param_count];
// Global thread index
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r1, %r2, %r3, %r1;
// Bounds check
setp.ge.u32 %p0, %r1, %r0;
@%p0 bra L_DONE;
// Load FP32
cvt.u64.u32 %rd2, %r1;
shl.b64 %rd3, %rd2, 2;
add.u64 %rd3, %rd1, %rd3;
ld.global.f32 %f0, [%rd3];
// Reinterpret as u32
mov.b32 %r4, %f0;
// Extract sign bit (bit 31) -> %r5
shr.u32 %r5, %r4, 31;
// Extract FP32 exponent (bits 30:23) -> %r6
bfe.u32 %r6, %r4, 23, 8;
// Extract FP32 mantissa (bits 22:0) -> %r7
and.b32 %r7, %r4, 0x007FFFFF;
// Check for zero/denorm (exp == 0) -> output 0x00
setp.eq.u32 %p1, %r6, 0;
@%p1 bra L_ZERO;
// Check for NaN/Inf (exp == 255) -> output sign | 0x7E (max finite)
setp.eq.u32 %p2, %r6, 255;
@%p2 bra L_NANINF;
// Rebias exponent: e4m3_exp = fp32_exp - 127 + 7 = fp32_exp - 120
// E4M3 valid biased exp range: 1..15 (unbiased -6..+8)
sub.u32 %r8, %r6, 120;
// Check underflow (fp32_exp < 121 -> e4m3_exp < 1 -> denorm/zero)
setp.lt.s32 %p1, %r8, 1;
@%p1 bra L_ZERO;
// Check overflow (e4m3_exp > 15 -> clamp to max finite)
setp.gt.s32 %p2, %r8, 15;
@%p2 bra L_NANINF;
// Round mantissa: 23 bits -> 3 bits (drop 20 low bits, RNE)
// Round bit = bit 19, sticky = bits 18:0
shr.u32 %r9, %r7, 20; // top 3 mantissa bits
bfe.u32 %r10, %r7, 19, 1; // round bit
and.b32 %r11, %r7, 0x0007FFFF; // sticky bits (bits 18:0)
// RNE: round up if (round && (sticky || lsb_of_result))
and.b32 %r12, %r9, 1; // lsb of truncated mantissa
or.b32 %r13, %r11, %r12; // sticky | lsb
setp.ne.u32 %p3, %r13, 0;
and.b32 %r14, %r10, 1; // round bit
// Only round up if round bit is set AND (sticky|lsb)
selp.u32 %r14, %r14, 0, %p3;
add.u32 %r9, %r9, %r14;
// Handle mantissa overflow (0b1000 -> increment exponent)
setp.gt.u32 %p3, %r9, 7;
@!%p3 bra L_PACK;
mov.u32 %r9, 0;
add.u32 %r8, %r8, 1;
// If exponent overflows to 16 -> max finite
setp.gt.s32 %p2, %r8, 15;
@%p2 bra L_NANINF;
L_PACK:
// Pack: sign(1) | exp(4) | mantissa(3)
shl.b32 %r5, %r5, 7;
shl.b32 %r8, %r8, 3;
or.b32 %r5, %r5, %r8;
or.b32 %r5, %r5, %r9;
bra L_STORE;
L_ZERO:
mov.u32 %r5, 0;
bra L_STORE;
L_NANINF:
// sign | 0x7E (max finite = 448.0)
shl.b32 %r5, %r5, 7;
or.b32 %r5, %r5, 0x7E;
L_STORE:
// Store 1 byte
add.u64 %rd4, %rd0, %rd2;
st.global.u8 [%rd4], %r5;
L_DONE:
ret;
}
"#;
const F32_TO_E4M3_SCALED_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry f32_to_e4m3_scaled(
.param .u64 param_dst,
.param .u64 param_src,
.param .u32 param_count,
.param .f32 param_scale
) {
.reg .u64 %rd<5>;
.reg .u32 %r<16>;
.reg .f32 %f<4>;
.reg .pred %p<4>;
ld.param.u64 %rd0, [param_dst];
ld.param.u64 %rd1, [param_src];
ld.param.u32 %r0, [param_count];
ld.param.f32 %f3, [param_scale];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r1, %r2, %r3, %r1;
setp.ge.u32 %p0, %r1, %r0;
@%p0 bra L_DONE_S;
cvt.u64.u32 %rd2, %r1;
shl.b64 %rd3, %rd2, 2;
add.u64 %rd3, %rd1, %rd3;
ld.global.f32 %f0, [%rd3];
mul.f32 %f0, %f0, %f3;
mov.b32 %r4, %f0;
shr.u32 %r5, %r4, 31;
bfe.u32 %r6, %r4, 23, 8;
and.b32 %r7, %r4, 0x007FFFFF;
setp.eq.u32 %p1, %r6, 0;
@%p1 bra L_ZERO_S;
setp.eq.u32 %p2, %r6, 255;
@%p2 bra L_NANINF_S;
sub.u32 %r8, %r6, 120;
setp.lt.s32 %p1, %r8, 1;
@%p1 bra L_ZERO_S;
setp.gt.s32 %p2, %r8, 15;
@%p2 bra L_NANINF_S;
shr.u32 %r9, %r7, 20;
bfe.u32 %r10, %r7, 19, 1;
and.b32 %r11, %r7, 0x0007FFFF;
and.b32 %r12, %r9, 1;
or.b32 %r13, %r11, %r12;
setp.ne.u32 %p3, %r13, 0;
and.b32 %r14, %r10, 1;
selp.u32 %r14, %r14, 0, %p3;
add.u32 %r9, %r9, %r14;
setp.gt.u32 %p3, %r9, 7;
@!%p3 bra L_PACK_S;
mov.u32 %r9, 0;
add.u32 %r8, %r8, 1;
setp.gt.s32 %p2, %r8, 15;
@%p2 bra L_NANINF_S;
L_PACK_S:
shl.b32 %r5, %r5, 7;
shl.b32 %r8, %r8, 3;
or.b32 %r5, %r5, %r8;
or.b32 %r5, %r5, %r9;
bra L_STORE_S;
L_ZERO_S:
mov.u32 %r5, 0;
bra L_STORE_S;
L_NANINF_S:
shl.b32 %r5, %r5, 7;
or.b32 %r5, %r5, 0x7E;
L_STORE_S:
add.u64 %rd4, %rd0, %rd2;
st.global.u8 [%rd4], %r5;
L_DONE_S:
ret;
}
"#;
const F32_TO_E4M3_DEVICE_SCALED_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry f32_to_e4m3_device_scaled(
.param .u64 param_dst, // FP8 output buffer
.param .u64 param_src, // FP32 input buffer
.param .u32 param_count, // number of elements
.param .u64 param_absmax_ptr, // device ptr to u32 (absmax as IEEE 754 bits)
.param .u64 param_dequant_ptr // device ptr to f32 where to write absmax/448
) {
.reg .u64 %rd<6>;
.reg .u32 %r<16>;
.reg .f32 %f<7>;
.reg .pred %p<5>;
ld.param.u64 %rd0, [param_dst];
ld.param.u64 %rd1, [param_src];
ld.param.u32 %r0, [param_count];
ld.param.u64 %rd4, [param_absmax_ptr];
ld.param.u64 %rd5, [param_dequant_ptr];
// Read absmax from device memory (u32 bits to f32)
ld.global.u32 %r15, [%rd4];
mov.b32 %f4, %r15;
// Handle absmax == 0: use 1.0 to avoid div-by-zero
setp.eq.f32 %p3, %f4, 0f00000000;
@%p3 mov.f32 %f4, 0f3F800000;
// Compute quant_scale = 448 / absmax (f5 = 448.0, reused below)
mov.f32 %f5, 0f43E00000;
div.rn.f32 %f3, %f5, %f4;
// Thread 0 of block 0: write dequant_scale = absmax / 448 to device buffer
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
or.b32 %r14, %r1, %r2;
setp.ne.u32 %p4, %r14, 0;
@%p4 bra L_SKIP_DEQUANT;
div.rn.f32 %f6, %f4, %f5;
st.global.f32 [%rd5], %f6;
L_SKIP_DEQUANT:
// One element per thread: compute global index and bounds-check (no grid-stride loop)
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r1, %r2, %r3, %r1;
setp.ge.u32 %p0, %r1, %r0;
@%p0 bra L_DONE_DS;
cvt.u64.u32 %rd2, %r1;
shl.b64 %rd3, %rd2, 2;
add.u64 %rd3, %rd1, %rd3;
ld.global.f32 %f0, [%rd3];
mul.f32 %f0, %f0, %f3;
// FP32 to FP8 E4M3 conversion (same logic as f32_to_e4m3_scaled)
mov.b32 %r4, %f0;
shr.u32 %r5, %r4, 31;
bfe.u32 %r6, %r4, 23, 8;
and.b32 %r7, %r4, 0x007FFFFF;
setp.eq.u32 %p1, %r6, 0;
@%p1 bra L_ZERO_DS;
setp.eq.u32 %p2, %r6, 255;
@%p2 bra L_NANINF_DS;
sub.u32 %r8, %r6, 120;
setp.lt.s32 %p1, %r8, 1;
@%p1 bra L_ZERO_DS;
setp.gt.s32 %p2, %r8, 15;
@%p2 bra L_NANINF_DS;
shr.u32 %r9, %r7, 20;
bfe.u32 %r10, %r7, 19, 1;
and.b32 %r11, %r7, 0x0007FFFF;
and.b32 %r12, %r9, 1;
or.b32 %r13, %r11, %r12;
setp.ne.u32 %p3, %r13, 0;
and.b32 %r14, %r10, 1;
selp.u32 %r14, %r14, 0, %p3;
add.u32 %r9, %r9, %r14;
setp.gt.u32 %p3, %r9, 7;
@!%p3 bra L_PACK_DS;
mov.u32 %r9, 0;
add.u32 %r8, %r8, 1;
setp.gt.s32 %p2, %r8, 15;
@%p2 bra L_NANINF_DS;
L_PACK_DS:
shl.b32 %r5, %r5, 7;
shl.b32 %r8, %r8, 3;
or.b32 %r5, %r5, %r8;
or.b32 %r5, %r5, %r9;
bra L_STORE_DS;
L_ZERO_DS:
mov.u32 %r5, 0;
bra L_STORE_DS;
L_NANINF_DS:
shl.b32 %r5, %r5, 7;
or.b32 %r5, %r5, 0x7E;
L_STORE_DS:
add.u64 %rd4, %rd0, %rd2;
st.global.u8 [%rd4], %r5;
L_DONE_DS:
ret;
}
"#;
const ABSMAX_REDUCE_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry absmax_reduce(
.param .u64 param_output,
.param .u64 param_input,
.param .u32 param_count
) {
.reg .u64 %rd<5>;
.reg .u32 %r<10>;
.reg .f32 %f<4>;
.reg .pred %p<2>;
.shared .align 4 .b32 sdata[256];
ld.param.u64 %rd0, [param_output];
ld.param.u64 %rd1, [param_input];
ld.param.u32 %r0, [param_count];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r4, %r2, %r3, %r1;
mov.u32 %r5, %nctaid.x;
mul.lo.u32 %r5, %r5, %r3;
mov.b32 %f0, 0x00000000;
L_LOOP:
setp.ge.u32 %p0, %r4, %r0;
@%p0 bra L_REDUCE;
cvt.u64.u32 %rd2, %r4;
shl.b64 %rd3, %rd2, 2;
add.u64 %rd3, %rd1, %rd3;
ld.global.f32 %f1, [%rd3];
abs.f32 %f1, %f1;
max.f32 %f0, %f0, %f1;
add.u32 %r4, %r4, %r5;
bra L_LOOP;
L_REDUCE:
// Store thread-local absmax into shared memory
// %r7 = &sdata[tid.x] (reused as this thread's shared addr)
mov.u32 %r6, sdata;
shl.b32 %r9, %r1, 2;
add.u32 %r7, %r6, %r9;
st.shared.f32 [%r7], %f0;
bar.sync 0;
// Tree reduction in shared memory
mov.u32 %r8, 128;
L_RED_LOOP:
setp.ge.u32 %p1, %r1, %r8;
@%p1 bra L_RED_DONE;
add.u32 %r9, %r1, %r8;
shl.b32 %r9, %r9, 2;
add.u32 %r9, %r6, %r9;
ld.shared.f32 %f2, [%r9];
max.f32 %f0, %f0, %f2;
st.shared.f32 [%r7], %f0;
L_RED_DONE:
bar.sync 0;
shr.u32 %r8, %r8, 1;
setp.ne.u32 %p1, %r8, 0;
@%p1 bra L_RED_LOOP;
// Thread 0 publishes the block result with an atomic max on the raw bits.
// u32 max is order-preserving here because every value is an absolute
// value: non-negative IEEE 754 floats compare the same as their bit patterns.
setp.ne.u32 %p0, %r1, 0;
@%p0 bra L_EXIT;
mov.b32 %r9, %f0;
atom.global.max.u32 %r9, [%rd0], %r9;
L_EXIT:
ret;
}
"#;
const F32_TO_F16_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry f32_to_f16(
.param .u64 param_dst,
.param .u64 param_src,
.param .u32 param_count
) {
.reg .u64 %rd<5>;
.reg .u32 %r<4>;
.reg .f32 %f0;
.reg .b16 %h0;
.reg .pred %p0;
ld.param.u64 %rd0, [param_dst];
ld.param.u64 %rd1, [param_src];
ld.param.u32 %r0, [param_count];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r1, %r2, %r3, %r1;
setp.ge.u32 %p0, %r1, %r0;
@%p0 bra L_DONE;
cvt.u64.u32 %rd2, %r1;
shl.b64 %rd3, %rd2, 2;
add.u64 %rd3, %rd1, %rd3;
ld.global.f32 %f0, [%rd3];
cvt.rn.f16.f32 %h0, %f0;
shl.b64 %rd4, %rd2, 1;
add.u64 %rd4, %rd0, %rd4;
st.global.b16 [%rd4], %h0;
L_DONE:
ret;
}
"#;
const F16_TO_F32_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry f16_to_f32(
.param .u64 param_dst,
.param .u64 param_src,
.param .u32 param_count
) {
.reg .u64 %rd<5>;
.reg .u32 %r<4>;
.reg .f32 %f0;
.reg .b16 %h0;
.reg .pred %p0;
ld.param.u64 %rd0, [param_dst];
ld.param.u64 %rd1, [param_src];
ld.param.u32 %r0, [param_count];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r1, %r2, %r3, %r1;
setp.ge.u32 %p0, %r1, %r0;
@%p0 bra L_DONE;
// Load FP16 (16 bits)
cvt.u64.u32 %rd2, %r1;
shl.b64 %rd3, %rd2, 1;
add.u64 %rd3, %rd1, %rd3;
ld.global.b16 %h0, [%rd3];
// Convert FP16 -> FP32 (hardware instruction)
cvt.f32.f16 %f0, %h0;
// Store as FP32 (4 bytes)
shl.b64 %rd4, %rd2, 2;
add.u64 %rd4, %rd0, %rd4;
st.global.f32 [%rd4], %f0;
L_DONE:
ret;
}
"#;
const F16_TO_F32_ACT_SCALED_PTX: &str = r#"
.version 7.5
.target sm_75
.address_size 64
.visible .entry f16_to_f32_act_scaled(
.param .u64 param_dst,
.param .u64 param_src,
.param .u32 param_count,
.param .u64 param_act_dequant_ptr
) {
.reg .u64 %rd<6>;
.reg .u32 %r<4>;
.reg .f32 %f<2>;
.reg .b16 %h0;
.reg .pred %p0;
ld.param.u64 %rd0, [param_dst];
ld.param.u64 %rd1, [param_src];
ld.param.u32 %r0, [param_count];
ld.param.u64 %rd5, [param_act_dequant_ptr];
// Read act_dequant scale from device (same for all elements)
ld.global.f32 %f1, [%rd5];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.u32 %r1, %r2, %r3, %r1;
setp.ge.u32 %p0, %r1, %r0;
@%p0 bra L_DONE_AS;
// Load FP16 (16 bits)
cvt.u64.u32 %rd2, %r1;
shl.b64 %rd3, %rd2, 1;
add.u64 %rd3, %rd1, %rd3;
ld.global.b16 %h0, [%rd3];
// Convert FP16 to FP32 then scale by act_dequant
cvt.f32.f16 %f0, %h0;
mul.f32 %f0, %f0, %f1;
// Store as FP32 (4 bytes)
shl.b64 %rd4, %rd2, 2;
add.u64 %rd4, %rd0, %rd4;
st.global.f32 [%rd4], %f0;
L_DONE_AS:
ret;
}
"#;
impl CudaExecutor {
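/// Lazily creates the cuBLAS handle and binds it to this executor's stream.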
pub(crate) fn ensure_cublas(&mut self) -> Result<(), GpuError> {
if self.cublas_handle.is_some() {
return Ok(());
}
let handle = trueno_gpu::driver::CublasHandle::new(&self.context)?;
handle.set_stream(&self.stream)?;
self.cublas_handle = Some(handle);
Ok(())
}
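/// Lazily pre-allocates a fixed 32 MB cuBLAS workspace and hands it to the
/// handle, so cuBLAS never allocates mid-stream (an allocation during CUDA
/// graph capture would abort the capture).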
pub(crate) fn ensure_cublas_workspace(&mut self) -> Result<(), GpuError> {
if self.cublas_workspace.is_some() {
return Ok(());
}
self.ensure_cublas()?;
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024;
let workspace = GpuBuffer::<u8>::new(&self.context, WORKSPACE_SIZE)?;
let handle = self.cublas_handle.as_ref().expect("cublas initialized");
handle.set_workspace(workspace.as_ptr(), WORKSPACE_SIZE)?;
eprintln!(
"[PMAT-063] cuBLAS workspace: {} MB pre-allocated for graph capture",
WORKSPACE_SIZE / 1024 / 1024
);
self.cublas_workspace = Some(workspace);
Ok(())
}
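/// Grows (never shrinks) the dequantization scratch buffer to hold at
/// least `n * k` FP32 elements.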
fn ensure_dequant_scratch(&mut self, n: u32, k: u32) -> Result<(), GpuError> {
let needed = n as usize * k as usize;
if self.dequant_scratch_size >= needed {
return Ok(());
}
self.dequant_scratch = Some(GpuBuffer::new(&self.context, needed)?);
self.dequant_scratch_size = needed;
Ok(())
}
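/// Grows the FP16 activation scratch buffer to hold at least `count`
/// elements.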
pub(crate) fn ensure_fp16_activation_scratch(&mut self, count: usize) -> Result<(), GpuError> {
if self.fp16_activation_scratch_size >= count {
return Ok(());
}
self.fp16_activation_scratch = Some(GpuBuffer::new(&self.context, count)?);
self.fp16_activation_scratch_size = count;
Ok(())
}
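/// Converts `count` FP32 values at `src_ptr` to FP16 at `dst_ptr` by
/// launching the `f32_to_f16` kernel (256 threads per block), compiling
/// and caching the module on first use.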
pub(crate) fn convert_f32_to_f16(
&mut self,
src_ptr: u64,
dst_ptr: u64,
count: u32,
) -> Result<(), GpuError> {
if !self.modules.contains_key("f32_to_f16") {
let module = self.compile_ptx(F32_TO_F16_PTX)?;
self.modules.insert("f32_to_f16".to_string(), module);
}
let module = self.modules.get_mut("f32_to_f16").expect("just inserted");
let config = LaunchConfig::linear(count, 256);
let mut dst = dst_ptr;
let mut src = src_ptr;
let mut cnt = count;
unsafe {
self.stream.launch_kernel(
module,
"f32_to_f16",
&config,
&mut [
std::ptr::from_mut(&mut dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
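/// Grows the FP8 activation scratch buffer to hold at least `count` bytes.
/// Reallocation discards the old contents, so the activation cache key is
/// invalidated alongside it.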
fn ensure_fp8_activation_scratch(&mut self, count: usize) -> Result<(), GpuError> {
if self.fp8_activation_scratch_size >= count {
return Ok(());
}
self.fp8_activation_scratch = Some(GpuBuffer::new(&self.context, count)?);
self.fp8_activation_scratch_size = count;
self.fp8_activation_cache_key = None;
Ok(())
}
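/// Launches the unscaled `f32_to_e4m3` conversion kernel (see
/// `F32_TO_E4M3_PTX`).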
fn convert_f32_to_e4m3(
&mut self,
src_ptr: u64,
dst_ptr: u64,
count: u32,
) -> Result<(), GpuError> {
if !self.modules.contains_key("f32_to_e4m3") {
let module = self.compile_ptx(F32_TO_E4M3_PTX)?;
self.modules.insert("f32_to_e4m3".to_string(), module);
}
let module = self.modules.get_mut("f32_to_e4m3").expect("just inserted");
let config = LaunchConfig::linear(count, 256);
let mut dst = dst_ptr;
let mut src = src_ptr;
let mut cnt = count;
unsafe {
self.stream.launch_kernel(
module,
"f32_to_e4m3",
&config,
&mut [
std::ptr::from_mut(&mut dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
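/// Launches the `f16_to_f32` widening kernel (see `F16_TO_F32_PTX`).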
fn convert_f16_to_f32(
&mut self,
src_ptr: u64,
dst_ptr: u64,
count: u32,
) -> Result<(), GpuError> {
if !self.modules.contains_key("f16_to_f32") {
let module = self.compile_ptx(F16_TO_F32_PTX)?;
self.modules.insert("f16_to_f32".to_string(), module);
}
let module = self.modules.get_mut("f16_to_f32").expect("just inserted");
let config = LaunchConfig::linear(count, 256);
let mut dst = dst_ptr;
let mut src = src_ptr;
let mut cnt = count;
unsafe {
self.stream.launch_kernel(
module,
"f16_to_f32",
&config,
&mut [
std::ptr::from_mut(&mut dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
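/// Launches `f16_to_f32_act_scaled`: widens FP16 to FP32 and multiplies by
/// the activation dequantization scale stored at `act_dequant_ptr` on the
/// device.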
fn convert_f16_to_f32_act_scaled(
&mut self,
src_ptr: u64,
dst_ptr: u64,
count: u32,
act_dequant_ptr: u64,
) -> Result<(), GpuError> {
let cache_key = "f16_to_f32_act_scaled";
if !self.modules.contains_key(cache_key) {
let module = self.compile_ptx(F16_TO_F32_ACT_SCALED_PTX)?;
self.modules.insert(cache_key.to_string(), module);
}
let module = self.modules.get_mut(cache_key).expect("just inserted");
let config = LaunchConfig::linear(count, 256);
let mut dst = dst_ptr;
let mut src = src_ptr;
let mut cnt = count;
let mut act_deq = act_dequant_ptr;
unsafe {
self.stream.launch_kernel(
module,
"f16_to_f32_act_scaled",
&config,
&mut [
std::ptr::from_mut(&mut dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut act_deq) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
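/// Launches `f32_to_e4m3_scaled` with a host-side `quant_scale` (typically
/// 448 / absmax) passed as a kernel parameter.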
fn convert_f32_to_e4m3_scaled(
&mut self,
src_ptr: u64,
dst_ptr: u64,
count: u32,
quant_scale: f32,
) -> Result<(), GpuError> {
let cache_key = "f32_to_e4m3_scaled";
if !self.modules.contains_key(cache_key) {
let module = self.compile_ptx(F32_TO_E4M3_SCALED_PTX)?;
self.modules.insert(cache_key.to_string(), module);
}
let module = self.modules.get_mut(cache_key).expect("just inserted");
let num_blocks = count.div_ceil(256);
let config = LaunchConfig {
grid: (num_blocks, 1, 1),
block: (256, 1, 1),
shared_mem: 0,
};
let mut dst = dst_ptr;
let mut src = src_ptr;
let mut cnt = count;
let mut scale = quant_scale;
unsafe {
self.stream.launch_kernel(
module,
"f32_to_e4m3_scaled",
&config,
&mut [
std::ptr::from_mut(&mut dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut scale) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
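/// Computes the absolute maximum of `count` FP32 values on the GPU and
/// copies it back to the host. This path synchronizes the stream, so it
/// cannot run inside CUDA graph capture; `gpu_absmax_device` is the
/// capture-safe alternative.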
fn gpu_absmax(&mut self, src_ptr: u64, count: u32) -> Result<f32, GpuError> {
if !self.modules.contains_key("absmax_reduce") {
let module = self.compile_ptx(ABSMAX_REDUCE_PTX)?;
self.modules.insert("absmax_reduce".to_string(), module);
}
let mut result_buf = GpuBuffer::<u32>::new(&self.context, 1)?;
let result_ptr = result_buf.as_ptr();
result_buf.copy_from_host(&[0u32])?;
let module = self
.modules
.get_mut("absmax_reduce")
.expect("just inserted");
let num_blocks = count.div_ceil(256).min(256);
let config = LaunchConfig {
grid: (num_blocks, 1, 1),
block: (256, 1, 1),
shared_mem: 256 * 4,
};
let mut out = result_ptr;
let mut src = src_ptr;
let mut cnt = count;
unsafe {
self.stream.launch_kernel(
module,
"absmax_reduce",
&config,
&mut [
std::ptr::from_mut(&mut out) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
],
)?;
}
self.stream.synchronize()?;
let mut result_u32 = [0u32; 1];
result_buf.copy_to_host(&mut result_u32)?;
let absmax = f32::from_bits(result_u32[0]);
Ok(absmax)
}
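/// Capture-safe variant of `gpu_absmax`: the result stays in a persistent
/// device buffer whose pointer is returned instead of being copied back.
/// The zeroing copy is enqueued on the same stream as the reduction, so
/// stream ordering guarantees it completes first.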
fn gpu_absmax_device(&mut self, src_ptr: u64, count: u32) -> Result<u64, GpuError> {
if !self.modules.contains_key("absmax_reduce") {
let module = self.compile_ptx(ABSMAX_REDUCE_PTX)?;
self.modules.insert("absmax_reduce".to_string(), module);
}
if self.fp8_absmax_buf.is_none() {
self.fp8_absmax_buf = Some(GpuBuffer::<u32>::new(&self.context, 1)?);
}
let absmax_buf = self.fp8_absmax_buf.as_mut().expect("just allocated");
let result_ptr = absmax_buf.as_ptr();
let zero_val = [0u32; 1];
unsafe {
absmax_buf.copy_from_host_async(&zero_val, &self.stream)?;
}
let module = self
.modules
.get_mut("absmax_reduce")
.expect("just inserted");
let num_blocks = count.div_ceil(256).min(256);
let config = LaunchConfig {
grid: (num_blocks, 1, 1),
block: (256, 1, 1),
shared_mem: 256 * 4,
};
let mut out = result_ptr;
let mut src = src_ptr;
let mut cnt = count;
unsafe {
self.stream.launch_kernel(
module,
"absmax_reduce",
&config,
&mut [
std::ptr::from_mut(&mut out) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
],
)?;
}
Ok(result_ptr)
}
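/// Launches `f32_to_e4m3_device_scaled`: quantizes FP32 values to E4M3
/// using the absmax stored at `absmax_ptr` on the device (as produced by
/// `gpu_absmax_device`) and writes the matching dequantization scale to
/// `dequant_ptr`. A minimal capture-safe quantization sequence (a sketch;
/// buffer management elided, `src`/`dst`/`deq` are hypothetical pointers):
///
/// ```ignore
/// let absmax_ptr = self.gpu_absmax_device(src, n)?;
/// self.convert_f32_to_e4m3_device_scaled(src, dst, n, absmax_ptr, deq)?;
/// ```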
fn convert_f32_to_e4m3_device_scaled(
&mut self,
src_ptr: u64,
dst_ptr: u64,
count: u32,
absmax_ptr: u64,
dequant_ptr: u64,
) -> Result<(), GpuError> {
let cache_key = "f32_to_e4m3_device_scaled";
if !self.modules.contains_key(cache_key) {
let module = self.compile_ptx(F32_TO_E4M3_DEVICE_SCALED_PTX)?;
self.modules.insert(cache_key.to_string(), module);
}
let module = self.modules.get_mut(cache_key).expect("just inserted");
let num_blocks = count.div_ceil(256);
let config = LaunchConfig {
grid: (num_blocks, 1, 1),
block: (256, 1, 1),
shared_mem: 0,
};
let mut dst = dst_ptr;
let mut src = src_ptr;
let mut cnt = count;
let mut abs_ptr = absmax_ptr;
let mut deq_ptr = dequant_ptr;
unsafe {
self.stream.launch_kernel(
module,
"f32_to_e4m3_device_scaled",
&config,
&mut [
std::ptr::from_mut(&mut dst) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut src) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut cnt) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut abs_ptr) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut deq_ptr) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
}