use std::ffi::c_void;
use std::sync::Arc;
use oxicuda_driver::{
Context, CudaResult, Device, JitOptions, Module, Stream, ffi::CUdeviceptr, loader::try_driver,
};
/// Chooses the PTX `.version` and `.target` values to emit for a GPU.
///
/// # Parameters
/// * `major`, `minor` - the device compute capability (for example `8, 6`).
///
/// # Returns
/// `(ptx_major, ptx_minor, target)` where `target` is `sm_<major><minor>`.
///
/// The version must be at least the PTX ISA release that introduced the
/// target, otherwise the JIT rejects the module. Per the PTX ISA release
/// notes: `sm_89`/`sm_90` arrived in 7.8, `sm_100`/`sm_101` in 8.6, `sm_120`
/// in 8.7, and `sm_103`/`sm_121` in 8.8. Unlisted capabilities keep the
/// conservative 7.4 default.
fn ptx_version_for_cc(major: i32, minor: i32) -> (u32, u32, String) {
    let (ptx_major, ptx_minor) = match (major, minor) {
        (7, 5) => (7, 4),
        (8, 0) | (8, 6) => (7, 5),
        (8, 9) | (9, 0) | (9, 1) => (8, 0),
        // Blackwell data-centre parts: 8.5 predates these targets.
        (10, 0) | (10, 1) => (8, 6),
        (10, _) => (8, 8),
        (12, 0) => (8, 7),
        (12, _) => (8, 8),
        _ => (7, 4),
    };
    (ptx_major, ptx_minor, format!("sm_{major}{minor}"))
}
/// Builds the PTX source text for the `vector_add` kernel.
///
/// The `.version` and `.target` directives are filled in from
/// `ptx_version_for_cc(cc_major, cc_minor)`; the rest of the text is fixed.
///
/// The generated entry point takes four parameters: three 64-bit addresses
/// (`param_a`, `param_b`, `param_c`) and a 32-bit element count (`param_n`).
/// Each thread computes its index as `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= param_n`, and otherwise loads one
/// `f32` from `param_a` and `param_b` at that index, adds them, and stores
/// the sum at the same index in `param_c`.
///
/// # Parameters
/// * `cc_major`, `cc_minor` - compute capability forwarded to
///   `ptx_version_for_cc`.
///
/// # Returns
/// The complete PTX module as a `String`.
fn make_vector_add_ptx(cc_major: i32, cc_minor: i32) -> String {
let (ptx_major, ptx_minor, sm_str) = ptx_version_for_cc(cc_major, cc_minor);
// Inside the template, `{{` and `}}` are `format!` escapes that produce the
// literal braces delimiting the PTX function body.
format!(
r#".version {ptx_major}.{ptx_minor}
.target {sm_str}
.address_size 64
// vector_add: C[tid] = A[tid] + B[tid] (f32, one thread per element)
.visible .entry vector_add(
.param .u64 param_a,
.param .u64 param_b,
.param .u64 param_c,
.param .u32 param_n
)
{{
.reg .u32 %r<4>;
.reg .u64 %rd<5>;
.reg .f32 %f<3>;
.reg .pred %p0;
// tid = blockIdx.x * blockDim.x + threadIdx.x
mov.u32 %r0, %tid.x;
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mad.lo.u32 %r0, %r2, %r1, %r0;
// Bounds check.
ld.param.u32 %r3, [param_n];
setp.ge.u32 %p0, %r0, %r3;
@%p0 bra $exit;
// Load base pointers.
ld.param.u64 %rd0, [param_a];
ld.param.u64 %rd1, [param_b];
ld.param.u64 %rd2, [param_c];
// Byte offset = tid * sizeof(f32) = tid << 2.
cvt.u64.u32 %rd3, %r0;
shl.b64 %rd3, %rd3, 2;
// A[tid]
add.u64 %rd4, %rd0, %rd3;
ld.global.f32 %f0, [%rd4];
// B[tid]
add.u64 %rd4, %rd1, %rd3;
ld.global.f32 %f1, [%rd4];
// C[tid] = A[tid] + B[tid]
add.f32 %f2, %f0, %f1;
add.u64 %rd4, %rd2, %rd3;
st.global.f32 [%rd4], %f2;
$exit:
ret;
}}
"#
)
}
/// Allocates device memory large enough for `count` values of type `T`.
///
/// # Parameters
/// * `count` - number of `T` elements the buffer must hold.
///
/// # Returns
/// The device pointer written by the driver's `cu_mem_alloc_v2` entry point.
///
/// # Errors
/// Propagates the error from `try_driver()` and any error reported by the
/// allocation call itself.
fn device_alloc<T>(count: usize) -> CudaResult<CUdeviceptr> {
    let api = try_driver()?;
    // A plain `count * size_of::<T>()` panics on overflow in debug builds and
    // wraps in release builds, where it would silently request a buffer far
    // smaller than the caller asked for. Saturating keeps the request
    // oversized, leaving the driver to refuse it (expected: an allocation
    // error) through the ordinary `CudaResult` path.
    let bytes = count.saturating_mul(std::mem::size_of::<T>());
    let mut ptr: CUdeviceptr = 0;
    oxicuda_driver::cuda_call!((api.cu_mem_alloc_v2)(&mut ptr, bytes))?;
    Ok(ptr)
}
fn htod<T>(dst: CUdeviceptr, src: &[T]) -> CudaResult<()> {
let api = try_driver()?;
let bytes = std::mem::size_of_val(src);
oxicuda_driver::cuda_call!((api.cu_memcpy_htod_v2)(
dst,
src.as_ptr().cast::<c_void>(),
bytes
))
}
fn dtoh<T>(dst: &mut [T], src: CUdeviceptr) -> CudaResult<()> {
let api = try_driver()?;
let bytes = std::mem::size_of_val(dst);
oxicuda_driver::cuda_call!((api.cu_memcpy_dtoh_v2)(
dst.as_mut_ptr().cast::<c_void>(),
src,
bytes
))
}
/// Releases a device allocation on a best-effort basis.
///
/// Nothing is reported to the caller: if the driver table cannot be obtained
/// the function returns without doing anything, and the status returned by
/// `cu_mem_free_v2` is discarded.
fn device_free(ptr: CUdeviceptr) {
    let Ok(api) = try_driver() else {
        return;
    };
    // SAFETY (assumed, TODO confirm): callers pass a pointer previously
    // returned by `device_alloc` that has not been freed yet.
    let _ = unsafe { (api.cu_mem_free_v2)(ptr) };
}
/// Launches `func` on `stream` with a one-dimensional grid and block.
///
/// # Parameters
/// * `func` - kernel handle to launch.
/// * `grid_x` - first grid dimension; the other two are fixed at 1.
/// * `block_x` - first block dimension; the other two are fixed at 1.
/// * `stream` - stream whose raw handle is passed to the driver.
/// * `params` - slice of kernel-argument pointers, passed as the
///   kernel-parameter array.
///
/// The remaining arguments are a literal `0` and a null pointer. Assuming
/// `cu_launch_kernel` follows the CUDA driver `cuLaunchKernel` argument
/// order, those are the dynamic shared-memory size and the `extra` array.
///
/// # Errors
/// Propagates the error from `try_driver()` or from the launch call.
fn launch(
    func: &oxicuda_driver::Function,
    grid_x: u32,
    block_x: u32,
    stream: &Stream,
    params: &[*mut c_void],
) -> CudaResult<()> {
    let api = try_driver()?;
    // The driver entry point wants a mutable pointer-to-pointer even though
    // the slice is only borrowed immutably here.
    let kernel_args = params.as_ptr() as *mut *mut c_void;
    oxicuda_driver::cuda_call!((api.cu_launch_kernel)(
        func.raw(),
        grid_x,
        1,
        1,
        block_x,
        1,
        1,
        0,
        stream.raw(),
        kernel_args,
        std::ptr::null_mut(),
    ))
}
/// Runs a 1024-element `vector_add` on device 0 and verifies it on the host.
///
/// Sequence: initialise the driver, open device 0 and a context, JIT-compile
/// the PTX from `make_vector_add_ptx`, allocate and upload the two inputs,
/// launch the kernel on a fresh stream, synchronise, download the output and
/// compare every element with the host-side sum.
///
/// Every failure prints a message to stderr and exits with status 1. Once all
/// three device buffers exist they are released on every exit path, matching
/// what the allocation-failure paths already do.
fn main() {
    if let Err(e) = oxicuda_driver::init() {
        eprintln!("Failed to load CUDA driver: {e}");
        eprintln!("Is an NVIDIA GPU driver installed?");
        std::process::exit(1);
    }

    let dev = Device::get(0).unwrap_or_else(|e| {
        eprintln!("No GPU found: {e}");
        std::process::exit(1)
    });
    // If the capability query fails, generate PTX as for sm_75.
    let (cc_major, cc_minor) = dev.compute_capability().unwrap_or((7, 5));
    println!(
        "Device: {} (sm_{cc_major}{cc_minor})",
        dev.name().unwrap_or_else(|_| "<unknown>".into())
    );

    let ctx = Context::new(&dev).map(Arc::new).unwrap_or_else(|e| {
        eprintln!("Failed to create CUDA context: {e}");
        std::process::exit(1)
    });

    // --- JIT compilation -------------------------------------------------
    let ptx = make_vector_add_ptx(cc_major, cc_minor);
    let opts = JitOptions {
        optimization_level: 4,
        target_from_context: true,
        ..Default::default()
    };
    let (module, log) = Module::from_ptx_with_options(&ptx, &opts).unwrap_or_else(|e| {
        eprintln!("PTX JIT compilation failed: {e}");
        std::process::exit(1)
    });
    if !log.info.is_empty() {
        println!("JIT info: {}", log.info);
    }
    if log.has_errors() {
        eprintln!("JIT errors:\n{}", log.error);
        for d in log.errors() {
            eprintln!(
                " [{sev}]{kern} {msg}",
                sev = d.severity,
                kern = d
                    .kernel
                    .as_deref()
                    .map(|k| format!(" ({k})"))
                    .unwrap_or_default(),
                msg = d.message,
            );
        }
        std::process::exit(1);
    }
    let func = module.get_function("vector_add").unwrap_or_else(|e| {
        eprintln!("Kernel lookup failed: {e}");
        std::process::exit(1)
    });

    // --- Problem setup ---------------------------------------------------
    const N: usize = 1024;
    const BLOCK: u32 = 256;
    // Enough blocks to cover N even when N is not a multiple of BLOCK.
    const GRID: u32 = (N as u32).div_ceil(BLOCK);
    // Absolute error allowed between the device result and the host sum.
    const TOLERANCE: f32 = 1e-5;

    let h_a: Vec<f32> = (1..=N).map(|i| i as f32).collect();
    let h_b: Vec<f32> = (1..=N).map(|i| (i * 2) as f32).collect();
    let mut h_c = vec![0.0f32; N];

    let d_a = device_alloc::<f32>(N).unwrap_or_else(|e| {
        eprintln!("Device alloc A: {e}");
        std::process::exit(1)
    });
    let d_b = device_alloc::<f32>(N).unwrap_or_else(|e| {
        device_free(d_a);
        eprintln!("Device alloc B: {e}");
        std::process::exit(1)
    });
    let d_c = device_alloc::<f32>(N).unwrap_or_else(|e| {
        device_free(d_a);
        device_free(d_b);
        eprintln!("Device alloc C: {e}");
        std::process::exit(1)
    });
    // `std::process::exit` runs no destructors, so each exit path below
    // releases the buffers explicitly through this closure.
    let release_buffers = move || {
        device_free(d_a);
        device_free(d_b);
        device_free(d_c);
    };

    if let Err(e) = htod(d_a, &h_a) {
        eprintln!("H2D A: {e}");
        release_buffers();
        std::process::exit(1);
    }
    if let Err(e) = htod(d_b, &h_b) {
        eprintln!("H2D B: {e}");
        release_buffers();
        std::process::exit(1);
    }

    // --- Launch ----------------------------------------------------------
    let stream = Stream::new(&ctx).unwrap_or_else(|e| {
        eprintln!("Stream: {e}");
        release_buffers();
        std::process::exit(1)
    });

    // The parameter array holds one pointer per kernel argument, so every
    // argument needs its own addressable local that outlives the launch.
    let mut arg_a = d_a;
    let mut arg_b = d_b;
    let mut arg_c = d_c;
    let mut arg_n = N as u32; // N is a small compile-time constant; no truncation.
    let params: [*mut c_void; 4] = [
        (&raw mut arg_a).cast::<c_void>(),
        (&raw mut arg_b).cast::<c_void>(),
        (&raw mut arg_c).cast::<c_void>(),
        (&raw mut arg_n).cast::<c_void>(),
    ];
    if let Err(e) = launch(&func, GRID, BLOCK, &stream, &params) {
        eprintln!("Kernel launch: {e}");
        release_buffers();
        std::process::exit(1);
    }
    if let Err(e) = stream.synchronize() {
        eprintln!("Sync: {e}");
        release_buffers();
        std::process::exit(1);
    }
    if let Err(e) = dtoh(&mut h_c, d_c) {
        eprintln!("D2H: {e}");
        release_buffers();
        std::process::exit(1);
    }

    // --- Verification ----------------------------------------------------
    // One predicate for both the failure count and the preview below. It is
    // written as `<=` so a NaN result compares false and counts as a
    // failure; the earlier `abs() > tol` form let NaN pass silently.
    let within_tolerance = |got: f32, expected: f32| (got - expected).abs() <= TOLERANCE;

    let mut fail = 0usize;
    for (i, ((&a, &b), &c)) in h_a.iter().zip(&h_b).zip(&h_c).enumerate() {
        let expected = a + b;
        if !within_tolerance(c, expected) {
            eprintln!("Mismatch [{i}]: got {}, expected {}", c, expected);
            fail += 1;
        }
    }

    // Preview the first few elements.
    for (i, ((&a, &b), &c)) in h_a.iter().zip(&h_b).zip(&h_c).enumerate().take(4) {
        let mark = if within_tolerance(c, a + b) {
            "\u{2713}"
        } else {
            "\u{2717}"
        };
        println!("C[{i}] = {} + {} = {} {}", a, b, c, mark);
    }

    if fail != 0 {
        eprintln!("{fail} verification failure(s)!");
        release_buffers();
        std::process::exit(1);
    }
    println!("All {N} elements verified.");
    release_buffers();
}