use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
/// Builds the PTX for `relu_f32`: `y[i] = max(x[i], 0.0)` for every `i < n`.
///
/// Kernel parameters, in order: `x` (u64 pointer to the input), `y` (u64 pointer to
/// the output) and `n` (u32 element count). Targets SM 8.0 with at most 256 threads
/// per block; each thread handles the element at its global x index.
fn generate_relu_kernel() -> Result<String, oxicuda_ptx::error::PtxGenError> {
    KernelBuilder::new("relu_f32")
        .target(SmVersion::Sm80)
        .param("x", PtxType::U64)
        .param("y", PtxType::U64)
        .param("n", PtxType::U32)
        .max_threads_per_block(256)
        .body(|k| {
            let idx = k.global_thread_id_x();
            let len = k.load_param_u32("n");
            // Only threads with idx < n touch memory.
            k.if_lt_u32(idx.clone(), len, |k| {
                let src_base = k.load_param_u64("x");
                let dst_base = k.load_param_u64("y");
                let src_addr = k.f32_elem_addr(src_base, idx.clone());
                let value = k.load_global_f32(src_addr);
                // 0f00000000 is the IEEE-754 bit pattern of +0.0f32.
                // NOTE(review): the register comes from `mov_imm_u32`; confirm the builder
                // declares it with a type ptxas accepts for `mov.f32` / `max.f32`.
                let zero = k.mov_imm_u32(0);
                k.raw_ptx(&format!("mov.f32 {zero}, 0f00000000;"));
                let clamped = k.max_f32(value, zero);
                let dst_addr = k.f32_elem_addr(dst_base, idx.clone());
                k.store_global_f32(dst_addr, clamped);
            });
            k.ret();
        })
        .build()
}
/// Builds the PTX for `sigmoid_f32`: `y[i] = 1 / (1 + exp(-x[i]))` for every `i < n`.
///
/// `exp(-x)` is evaluated as `2^(-x * log2(e))` so it maps onto the approximate
/// base-2 exponential, and the final division uses an approximate reciprocal.
/// Kernel parameters, in order: `x` (u64 input pointer), `y` (u64 output pointer)
/// and `n` (u32 element count). Targets SM 8.0 with at most 256 threads per block.
fn generate_sigmoid_kernel() -> Result<String, oxicuda_ptx::error::PtxGenError> {
    // PTX f32 immediates are spelled `0f` + the 8 hex digits of the IEEE-754 bit
    // pattern, so both constants are converted to their raw bits up front.
    let log2e_bits: u32 = std::f32::consts::LOG2_E.to_bits();
    let one_bits: u32 = 1.0f32.to_bits();
    KernelBuilder::new("sigmoid_f32")
        .target(SmVersion::Sm80)
        .param("x", PtxType::U64)
        .param("y", PtxType::U64)
        .param("n", PtxType::U32)
        .max_threads_per_block(256)
        .body(move |k| {
            let idx = k.global_thread_id_x();
            let len = k.load_param_u32("n");
            // Only threads with idx < n touch memory.
            k.if_lt_u32(idx.clone(), len, |k| {
                let src_base = k.load_param_u64("x");
                let dst_base = k.load_param_u64("y");
                let src_addr = k.f32_elem_addr(src_base, idx.clone());
                let value = k.load_global_f32(src_addr);
                let neg_x = k.neg_f32(value);
                // Scratch register for -x * log2(e), filled by the raw `mul.f32` below.
                let neg_x_log2e = k.mov_imm_u32(0);
                k.raw_ptx(&format!(
                    "mul.f32 {neg_x_log2e}, {neg_x}, 0f{log2e_bits:08X};"
                ));
                // 2^(-x * log2 e) == exp(-x)
                let ex2_val = k.ex2_approx_f32(neg_x_log2e);
                // Scratch register for 1 + exp(-x).
                let one_plus = k.mov_imm_u32(0);
                k.raw_ptx(&format!("add.f32 {one_plus}, {ex2_val}, 0f{one_bits:08X};"));
                let sigmoid = k.rcp_approx_f32(one_plus);
                let dst_addr = k.f32_elem_addr(dst_base, idx.clone());
                k.store_global_f32(dst_addr, sigmoid);
            });
            k.ret();
        })
        .build()
}
/// Builds the PTX for `sum_reduce_f32`, a block-level parallel sum of `x[0..n]`.
///
/// Kernel parameters, in order: `x` (u64 input pointer), `result` (u64 pointer to an
/// f32 accumulator) and `n` (u32 element count). Each block:
/// 1. zeroes its `smem[tid]` slot, then copies `x[gid]` into it when `gid < n`;
/// 2. runs a halving tree: for each stride `BLOCK/2, BLOCK/4, ..., 1`, threads with
///    `tid < stride` add `smem[tid + stride]` into `smem[tid]`, with a barrier after
///    every step;
/// 3. has thread 0 atomically add the block total into `*result`.
///
/// Because the totals are accumulated with an atomic add, the caller presumably has to
/// zero `*result` before launch. The tree assumes every one of the `BLOCK` shared slots
/// is written in step 1, i.e. the kernel is launched with `BLOCK` threads per block.
fn generate_reduction_kernel() -> Result<String, oxicuda_ptx::error::PtxGenError> {
    /// Threads per block; also the number of f32 slots in the `smem` shared array.
    const BLOCK: u32 = 256;
    // The halving tree only visits every slot when BLOCK is a power of two.
    const _: () = assert!(BLOCK.is_power_of_two());
    KernelBuilder::new("sum_reduce_f32")
        .target(SmVersion::Sm80)
        .param("x", PtxType::U64)
        .param("result", PtxType::U64)
        .param("n", PtxType::U32)
        .shared_mem("smem", PtxType::F32, BLOCK as usize)
        .max_threads_per_block(BLOCK)
        .body(|b| {
            // Strides BLOCK/2, BLOCK/4, ..., 1, derived from BLOCK so the tree can never
            // drift out of sync with the shared-memory size (128, 64, ..., 1 for 256).
            let strides: Vec<u32> = (0..BLOCK.trailing_zeros())
                .rev()
                .map(|shift| 1u32 << shift)
                .collect();

            let gid = b.global_thread_id_x();
            let tid = b.thread_id_x();
            let n = b.load_param_u32("n");
            let x_base = b.load_param_u64("x");
            // Address of this thread's slot, smem[tid].
            // NOTE(review): `cvta.to.shared` expects a *generic* source address, and the
            // destination register was allocated via `mov_imm_u32` although it receives a
            // u64. A plain `mov.u64 <reg>, smem;` into a 64-bit register is the usual way
            // to get a shared-window address — confirm against ptxas.
            let smem_tid_addr = {
                let ptr = b.mov_imm_u32(0);
                b.raw_ptx(&format!("cvta.to.shared.u64 {ptr}, smem;"));
                b.f32_elem_addr(ptr, tid.clone())
            };
            // Zero the slot first so threads with gid >= n contribute 0.0 to the sum.
            let zero_f32 = b.mov_imm_u32(0);
            b.raw_ptx(&format!("mov.f32 {zero_f32}, 0f00000000;"));
            b.store_shared_f32(smem_tid_addr.clone(), zero_f32);
            b.if_lt_u32(gid.clone(), n.clone(), |b| {
                let x_addr = b.f32_elem_addr(x_base, gid.clone());
                let xv = b.load_global_f32(x_addr);
                b.store_shared_f32(smem_tid_addr.clone(), xv);
            });
            b.bar_sync(0);

            let stride_list = strides
                .iter()
                .map(u32::to_string)
                .collect::<Vec<_>>()
                .join(", ");
            b.comment(&format!("tree reduction: strides {stride_list}"));
            for &stride in &strides {
                let stride_upper = b.mov_imm_u32(stride);
                // Active threads read slots >= stride and write slots < stride, so there
                // is no overlap within a step; the barrier below orders the steps.
                b.if_lt_u32(tid.clone(), stride_upper, |b| {
                    let stride_reg = b.mov_imm_u32(stride);
                    let partner_idx = b.add_u32(tid.clone(), stride_reg);
                    let partner_addr = {
                        let ptr = b.mov_imm_u32(0);
                        b.raw_ptx(&format!("cvta.to.shared.u64 {ptr}, smem;"));
                        b.f32_elem_addr(ptr, partner_idx)
                    };
                    let pv = b.load_shared_f32(partner_addr);
                    let sv = b.load_shared_f32(smem_tid_addr.clone());
                    let sum = b.add_f32(sv, pv);
                    b.store_shared_f32(smem_tid_addr.clone(), sum);
                });
                b.bar_sync(0);
            }

            // After the tree, smem[0] holds the block total; thread 0 publishes it.
            let one_bound = b.mov_imm_u32(1);
            b.if_lt_u32(tid.clone(), one_bound, |b| {
                let block_sum = b.load_shared_f32(smem_tid_addr.clone());
                let res_base = b.load_param_u64("result");
                let _old = b.atom_global_add_f32(res_base, block_sum);
            });
            b.ret();
        })
        .build()
}
fn try_gpu_relu(relu_ptx: &str) -> Result<(), Box<dyn std::error::Error>> {
use oxicuda::prelude::*;
use oxicuda::{DeviceBuffer, LaunchParams};
use std::sync::Arc;
oxicuda::init()?;
let device = Device::get(0)?;
let (maj, min) = device.compute_capability()?;
println!(" GPU: {} ({maj}.{min})", device.name()?);
let ctx = Arc::new(Context::new(&device)?);
let stream = Stream::new(&ctx)?;
let n: u32 = 1024;
let host_x: Vec<f32> = (0..n).map(|i| (i as f32) - (n / 2) as f32).collect();
let mut host_y = vec![0.0f32; n as usize];
let dev_x = DeviceBuffer::<f32>::from_host(&host_x)?;
let dev_y = DeviceBuffer::<f32>::alloc(n as usize)?;
let module = Arc::new(oxicuda::Module::from_ptx(relu_ptx)?);
let kernel = oxicuda::Kernel::from_module(module, "relu_f32")?;
let block = 256u32;
let grid = n.div_ceil(block);
let params = LaunchParams::new(grid, block);
let args = (dev_x.as_device_ptr(), dev_y.as_device_ptr(), n);
kernel.launch(¶ms, &stream, &args)?;
stream.synchronize()?;
dev_y.copy_to_host(&mut host_y)?;
let mismatches = host_x
.iter()
.zip(host_y.iter())
.filter(|&(&x, &y)| {
let expected = x.max(0.0);
(y - expected).abs() > 1e-6
})
.count();
if mismatches == 0 {
println!(" ReLU GPU kernel: PASSED ({n} elements)");
} else {
eprintln!(" ReLU GPU kernel: FAILED ({mismatches} mismatches)");
}
Ok(())
}
/// Returns the first `max_chars` characters of `text` (all of it if shorter) together
/// with the number of characters that were cut off.
///
/// Cuts on a `char` boundary, so it cannot panic on multi-byte UTF-8.
fn char_prefix(text: &str, max_chars: usize) -> (&str, usize) {
    match text.char_indices().nth(max_chars) {
        Some((cut, _)) => (&text[..cut], text[cut..].chars().count()),
        None => (text, 0),
    }
}

/// Generates the ReLU, sigmoid and sum-reduction kernels, prints a size summary and
/// a preview of the ReLU PTX, then tries to run `relu_f32` on GPU 0.
///
/// A failed GPU attempt is printed but does not make the program fail.
///
/// # Errors
/// Returns an error if any of the PTX generators fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== OxiCUDA PTX Generation Demo ===\n");

    println!("--- 1. ReLU (elementwise max(0, x)) ---");
    let relu_ptx = generate_relu_kernel()?;
    println!("Generated {} bytes of PTX.", relu_ptx.len());
    // Preview only the start of the listing. Slicing at a raw byte offset could panic
    // in the middle of a multi-byte character, so cut on a char boundary instead.
    let (preview, remaining) = char_prefix(&relu_ptx, 500);
    println!("{preview}");
    if remaining > 0 {
        println!("... ({remaining} more chars)");
    }
    println!();

    println!("--- 2. Sigmoid (1 / (1 + exp(−x))) ---");
    let sigmoid_ptx = generate_sigmoid_kernel()?;
    println!("Generated {} bytes of PTX.\n", sigmoid_ptx.len());

    println!("--- 3. Parallel sum reduction (block-level) ---");
    let reduce_ptx = generate_reduction_kernel()?;
    println!("Generated {} bytes of PTX.\n", reduce_ptx.len());

    println!("Kernels generated:");
    println!(" relu_f32 : {} bytes", relu_ptx.len());
    println!(" sigmoid_f32 : {} bytes", sigmoid_ptx.len());
    println!(" sum_reduce_f32 : {} bytes", reduce_ptx.len());
    println!();

    println!("--- 4. Attempting GPU launch of relu_f32 ---");
    // A missing GPU/driver is an expected outcome for this demo, not a failure.
    match try_gpu_relu(&relu_ptx) {
        Ok(()) => println!("GPU launch succeeded!"),
        Err(e) => println!(
            "GPU not available: {} (expected on macOS / no-GPU systems)",
            e
        ),
    }
    Ok(())
}