use std::sync::Arc;
use oxicuda::prelude::*;
use oxicuda::{DeviceBuffer, LaunchParams};
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
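
/// Tile edge length: each thread block computes one TILE×TILE sub-block of C
/// using two TILE×TILE shared-memory tiles of A and B.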
const TILE: u32 = 16;
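
/// Builds PTX for a tiled single-precision GEMM kernel (row-major C = A · B).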
fn build_sgemm_ptx() -> Result<String, oxicuda_ptx::error::PtxGenError> {
    let tile = TILE;
    KernelBuilder::new("sgemm_tiled")
        .target(SmVersion::Sm80)
        .param("a", PtxType::U64)
        .param("b", PtxType::U64)
        .param("c", PtxType::U64)
        .param("m", PtxType::U32)
        .param("n", PtxType::U32)
        .param("k", PtxType::U32)
        .shared_mem("smem_a", PtxType::F32, (tile * tile) as usize)
        .shared_mem("smem_b", PtxType::F32, (tile * tile) as usize)
        .max_threads_per_block(tile * tile)
        .body(move |b| {
            b.comment("=== sgemm_tiled: one thread → one C element ===");
            let m_reg = b.load_param_u32("m");
            let n_reg = b.load_param_u32("n");
            let k_reg = b.load_param_u32("k");
            let a_base = b.load_param_u64("a");
            let b_base = b.load_param_u64("b");
            let c_base = b.load_param_u64("c");
            let row = b.global_thread_id_y();
            let col = b.global_thread_id_x();
            let tx = b.thread_id_x();
            // Reserve registers via mov_imm_u32, then overwrite them with raw PTX:
            // ty_local <- %tid.y, acc <- 0.0 (0f00000000 is the PTX hex form of 0.0f32).
            let ty_local = b.mov_imm_u32(0);
            b.raw_ptx(&format!("mov.u32 {ty_local}, %tid.y;"));
            let acc = b.mov_imm_u32(0);
            b.raw_ptx(&format!("mov.f32 {acc}, 0f00000000;"));
            // tile_col walks the K dimension in steps of TILE; TILE_LOOP is the loop head.
            let tile_col = b.mov_imm_u32(0);
            b.raw_ptx("TILE_LOOP:");
b.comment("--- load A[row][tile_col + tx] into smem_a[ty_local][tx] ---");
let a_inner = {
let tc_plus_tx = b.add_u32(tile_col.clone(), tx.clone());
b.mad_lo_u32(row.clone(), k_reg.clone(), tc_plus_tx)
};
let a_elem_addr = b.f32_elem_addr(a_base.clone(), a_inner);
let a_val = b.load_global_f32(a_elem_addr);
let smem_a_flat = {
let tile_imm = b.mov_imm_u32(tile);
b.mad_lo_u32(ty_local.clone(), tile_imm, tx.clone())
};
let smem_a_ptr = {
let ptr = b.mov_imm_u32(0);
b.raw_ptx(&format!("cvta.to.shared.u64 {ptr}, smem_a;"));
b.f32_elem_addr(ptr, smem_a_flat)
};
b.store_shared_f32(smem_a_ptr, a_val);
b.comment("--- load B[tile_col + ty_local][col] into smem_b[ty_local][tx] ---");
let b_inner = {
let tc_plus_ty = b.add_u32(tile_col.clone(), ty_local.clone());
b.mad_lo_u32(tc_plus_ty, n_reg.clone(), col.clone())
};
let b_elem_addr = b.f32_elem_addr(b_base.clone(), b_inner);
let b_val = b.load_global_f32(b_elem_addr);
let smem_b_flat = {
let tile_imm = b.mov_imm_u32(tile);
b.mad_lo_u32(ty_local.clone(), tile_imm, tx.clone())
};
let smem_b_ptr = {
let ptr = b.mov_imm_u32(0);
b.raw_ptx(&format!("cvta.to.shared.u64 {ptr}, smem_b;"));
b.f32_elem_addr(ptr, smem_b_flat)
};
b.store_shared_f32(smem_b_ptr, b_val);
b.bar_sync(0);
b.comment("--- inner product over the tile (unrolled TILE iterations) ---");
b.unroll(tile, |b, ki| {
let sa_idx = {
let tile_imm = b.mov_imm_u32(tile);
let ki_imm = b.mov_imm_u32(ki);
b.mad_lo_u32(ty_local.clone(), tile_imm, ki_imm)
};
let sa_ptr = {
let ptr = b.mov_imm_u32(0);
b.raw_ptx(&format!("cvta.to.shared.u64 {ptr}, smem_a;"));
b.f32_elem_addr(ptr, sa_idx)
};
let av = b.load_shared_f32(sa_ptr);
let sb_idx = {
let ki_imm = b.mov_imm_u32(ki);
let tile_imm = b.mov_imm_u32(tile);
b.mad_lo_u32(ki_imm, tile_imm, tx.clone())
};
let sb_ptr = {
let ptr = b.mov_imm_u32(0);
b.raw_ptx(&format!("cvta.to.shared.u64 {ptr}, smem_b;"));
b.f32_elem_addr(ptr, sb_idx)
};
let bv = b.load_shared_f32(sb_ptr);
let new_acc = b.fma_f32(av, bv, acc.clone());
b.raw_ptx(&format!("mov.f32 {acc}, {new_acc};"));
});
            b.bar_sync(0);
            // Advance tile_col by TILE and branch back to TILE_LOOP while tile_col < K.
            let next_col = {
                let step = b.mov_imm_u32(tile);
                b.add_u32(tile_col.clone(), step)
            };
            b.raw_ptx(&format!("mov.u32 {tile_col}, {next_col};"));
            b.raw_ptx(&format!("setp.lt.u32 %p_tl, {tile_col}, {k_reg};"));
            b.raw_ptx("@%p_tl bra TILE_LOOP;");
            // Guard the final store: only threads that map to a valid C element write out.
            b.if_lt_u32(row.clone(), m_reg.clone(), |b| {
                b.if_lt_u32(col.clone(), n_reg.clone(), |b| {
                    let c_flat = b.mad_lo_u32(row.clone(), n_reg.clone(), col.clone());
                    let c_ptr = b.f32_elem_addr(c_base.clone(), c_flat);
                    b.store_global_f32(c_ptr, acc.clone());
                });
            });
            b.ret();
        })
        .build()
}
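
/// Naive triple-loop reference GEMM on the CPU, used to validate the GPU result.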
fn cpu_gemm(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    for row in 0..m {
        for col in 0..n {
            let mut sum = 0.0f32;
            for ki in 0..k {
                sum += a[row * k + ki] * b[ki * n + col];
            }
            c[row * n + col] = sum;
        }
    }
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== OxiCUDA Matrix Multiplication (SGEMM) Example ===\n");
println!("Building PTX for sgemm_tiled (tile={}×{})...", TILE, TILE);
let ptx = build_sgemm_ptx()?;
println!("PTX generated successfully ({} bytes).", ptx.len());
println!(
"Shared memory per block: {} bytes (2 × {}×{} f32 tiles)\n",
2 * TILE * TILE * 4,
TILE,
TILE
);
let preview_len = ptx.len().min(800);
println!(
"--- PTX preview (first {} of {} chars) ---",
preview_len,
ptx.len()
);
println!("{}", &ptx[..preview_len]);
if ptx.len() > preview_len {
println!("... ({} more chars)", ptx.len() - preview_len);
}
println!("---\n");
match try_gpu_gemm(&ptx) {
Ok(()) => println!("\nGPU GEMM completed successfully."),
Err(e) => println!(
"\nGPU not available: {} (expected on macOS / no-GPU systems)",
e
),
}
Ok(())
}
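
/// Initializes the driver, runs sgemm_tiled on device 0, and compares the result
/// against `cpu_gemm`. Returns an error when no CUDA device is available.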
fn try_gpu_gemm(ptx: &str) -> Result<(), Box<dyn std::error::Error>> {
    oxicuda::init()?;
    let device = Device::get(0)?;
    println!("Using GPU: {}", device.name()?);
    let (maj, min) = device.compute_capability()?;
    println!(" Compute capability: {maj}.{min}");
    println!(
        " Total memory: {} MiB",
        device.total_memory()? / (1024 * 1024)
    );
    let ctx = Arc::new(Context::new(&device)?);
    let stream = Stream::new(&ctx)?;
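    // 32×32×32 keeps the CPU reference cheap; all dims are multiples of TILE,
    // so the kernel's unguarded tile loads stay in bounds.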
    let m: usize = 32;
    let n: usize = 32;
    let k: usize = 32;
    let host_a: Vec<f32> = (0..m * k).map(|i| (i % 7) as f32 * 0.1).collect();
    let host_b: Vec<f32> = (0..k * n).map(|i| (i % 5) as f32 * 0.1).collect();
    let mut host_c_gpu = vec![0.0f32; m * n];
    let mut host_c_cpu = vec![0.0f32; m * n];
    cpu_gemm(&host_a, &host_b, &mut host_c_cpu, m, n, k);
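    // Upload A and B, allocate space for C, and load the freshly built PTX module.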
    let dev_a = DeviceBuffer::<f32>::from_host(&host_a)?;
    let dev_b = DeviceBuffer::<f32>::from_host(&host_b)?;
    let dev_c = DeviceBuffer::<f32>::alloc(m * n)?;
    let module = Arc::new(oxicuda::Module::from_ptx(ptx)?);
    let kernel = oxicuda::Kernel::from_module(module, "sgemm_tiled")?;
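    // One TILE×TILE thread block per output tile; the grid rounds up to cover M×N.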
    let grid_x = (n as u32).div_ceil(TILE);
    let grid_y = (m as u32).div_ceil(TILE);
    let params = LaunchParams::builder()
        .grid(oxicuda::Dim3::new(grid_x, grid_y, 1))
        .block(oxicuda::Dim3::new(TILE, TILE, 1))
        .build();
    println!("\nLaunching sgemm_tiled: M={m}, N={n}, K={k}");
    println!(" Grid: {grid_x}×{grid_y}, Block: {TILE}×{TILE}");
    let args = (
        dev_a.as_device_ptr(),
        dev_b.as_device_ptr(),
        dev_c.as_device_ptr(),
        m as u32,
        n as u32,
        k as u32,
    );
    kernel.launch(&params, &stream, &args)?;
    stream.synchronize()?;
    dev_c.copy_to_host(&mut host_c_gpu)?;
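    // Validate against the CPU reference with an absolute tolerance of 1e-3.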
    let mut max_err = 0.0f32;
    let mut mismatches = 0usize;
    for (i, (&gv, &cv)) in host_c_gpu.iter().zip(host_c_cpu.iter()).enumerate() {
        let err = (gv - cv).abs();
        if err > max_err {
            max_err = err;
        }
        if err > 1e-3 {
            mismatches += 1;
            if mismatches <= 3 {
                eprintln!(" Mismatch [{i}]: GPU={gv:.6}, CPU={cv:.6}");
            }
        }
    }
    if mismatches == 0 {
        println!("SUCCESS: All {m}×{n} elements correct (max_err={max_err:.2e})");
    } else {
        eprintln!("FAILED: {mismatches} mismatches (max_err={max_err:.2e})");
    }
    Ok(())
}