use kaio_core::emit::{Emit, PtxWriter};
use kaio_core::instr::control::{CmpOp, ControlOp};
use kaio_core::instr::memory::MemoryOp;
use kaio_core::instr::special;
use kaio_core::instr::{ArithOp, MadMode};
use kaio_core::ir::{Operand, PtxInstruction, PtxKernel, PtxModule, PtxParam, RegisterAllocator};
use kaio_core::types::PtxType;
use cudarc::driver::PushKernelArg;
use kaio_runtime::{KaioDevice, LaunchConfig};
fn build_vector_add_module() -> PtxModule {
let mut alloc = RegisterAllocator::new();
let mut kernel = PtxKernel::new("vector_add");
kernel.add_param(PtxParam::pointer("a_ptr", PtxType::F32));
kernel.add_param(PtxParam::pointer("b_ptr", PtxType::F32));
kernel.add_param(PtxParam::pointer("c_ptr", PtxType::F32));
kernel.add_param(PtxParam::scalar("n", PtxType::U32));
let rd_a = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
dst: rd_a,
param_name: "a_ptr".to_string(),
ty: PtxType::U64,
}));
let rd_b = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
dst: rd_b,
param_name: "b_ptr".to_string(),
ty: PtxType::U64,
}));
let rd_c = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
dst: rd_c,
param_name: "c_ptr".to_string(),
ty: PtxType::U64,
}));
let r_n = alloc.alloc(PtxType::U32);
kernel.push(PtxInstruction::Memory(MemoryOp::LdParam {
dst: r_n,
param_name: "n".to_string(),
ty: PtxType::U32,
}));
let (r_ctaid, ctaid_instr) = special::ctaid_x(&mut alloc);
kernel.push(ctaid_instr);
let (r_ntid, ntid_instr) = special::ntid_x(&mut alloc);
kernel.push(ntid_instr);
let (r_tid, tid_instr) = special::tid_x(&mut alloc);
kernel.push(tid_instr);
let r_idx = alloc.alloc(PtxType::S32);
kernel.push(PtxInstruction::Arith(ArithOp::Mad {
dst: r_idx,
a: Operand::Reg(r_ctaid),
b: Operand::Reg(r_ntid),
c: Operand::Reg(r_tid),
ty: PtxType::S32,
mode: MadMode::Lo,
}));
let p_oob = alloc.alloc(PtxType::Pred);
kernel.push(PtxInstruction::Control(ControlOp::SetP {
dst: p_oob,
cmp_op: CmpOp::Ge,
lhs: Operand::Reg(r_idx),
rhs: Operand::Reg(r_n),
ty: PtxType::U32,
}));
kernel.push(PtxInstruction::Control(ControlOp::BraPred {
pred: p_oob,
target: "EXIT".to_string(),
negate: false,
}));
let rd_a_global = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Memory(MemoryOp::CvtaToGlobal {
dst: rd_a_global,
src: rd_a,
}));
let rd_offset = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Arith(ArithOp::MulWide {
dst: rd_offset,
lhs: Operand::Reg(r_idx),
rhs: Operand::ImmU32(4),
src_ty: PtxType::U32,
}));
let rd_a_addr = alloc.alloc(PtxType::S64);
kernel.push(PtxInstruction::Arith(ArithOp::Add {
dst: rd_a_addr,
lhs: Operand::Reg(rd_a_global),
rhs: Operand::Reg(rd_offset),
ty: PtxType::S64,
}));
let f_a = alloc.alloc(PtxType::F32);
kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
dst: f_a,
addr: rd_a_addr,
ty: PtxType::F32,
}));
let rd_b_global = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Memory(MemoryOp::CvtaToGlobal {
dst: rd_b_global,
src: rd_b,
}));
let rd_b_addr = alloc.alloc(PtxType::S64);
kernel.push(PtxInstruction::Arith(ArithOp::Add {
dst: rd_b_addr,
lhs: Operand::Reg(rd_b_global),
rhs: Operand::Reg(rd_offset),
ty: PtxType::S64,
}));
let f_b = alloc.alloc(PtxType::F32);
kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
dst: f_b,
addr: rd_b_addr,
ty: PtxType::F32,
}));
let f_c = alloc.alloc(PtxType::F32);
kernel.push(PtxInstruction::Arith(ArithOp::Add {
dst: f_c,
lhs: Operand::Reg(f_a),
rhs: Operand::Reg(f_b),
ty: PtxType::F32,
}));
let rd_c_global = alloc.alloc(PtxType::U64);
kernel.push(PtxInstruction::Memory(MemoryOp::CvtaToGlobal {
dst: rd_c_global,
src: rd_c,
}));
let rd_c_addr = alloc.alloc(PtxType::S64);
kernel.push(PtxInstruction::Arith(ArithOp::Add {
dst: rd_c_addr,
lhs: Operand::Reg(rd_c_global),
rhs: Operand::Reg(rd_offset),
ty: PtxType::S64,
}));
kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
addr: rd_c_addr,
src: f_c,
ty: PtxType::F32,
}));
kernel.push(PtxInstruction::Label("EXIT".to_string()));
kernel.push(PtxInstruction::Control(ControlOp::Ret));
kernel.set_registers(alloc.into_allocated());
let sm = std::env::var("KAIO_SM_TARGET").unwrap_or_else(|_| "sm_70".to_string());
let mut module = PtxModule::new(&sm);
module.add_kernel(kernel);
module
}
/// Renders `module` through a fresh [`PtxWriter`] and returns the resulting
/// text, so a failing test can print the PTX it tried to use.
///
/// # Panics
///
/// Panics if emitting the module into the writer returns an error.
fn emit_ptx_debug(module: &PtxModule) -> String {
    let mut writer = PtxWriter::new();
    module
        .emit(&mut writer)
        .unwrap();
    writer.finish()
}
/// Runs the `vector_add` kernel on three elements and checks every sum.
///
/// Ignored by default: it needs device 0 to be available. On a load or
/// launch failure the generated PTX is printed to stderr before panicking.
#[test]
#[ignore]
fn vector_add_small() {
    let ptx_module = build_vector_add_module();
    let device = KaioDevice::new(0).expect("GPU required");

    let module = match device.load_module(&ptx_module) {
        Ok(loaded) => loaded,
        Err(e) => {
            eprintln!(
                "=== PTX that failed to load ===\n{}",
                emit_ptx_debug(&ptx_module)
            );
            panic!("load_module failed: {e}");
        }
    };
    let func = match module.function("vector_add") {
        Ok(found) => found,
        Err(e) => panic!("function('vector_add') failed: {e}"),
    };

    // Host inputs and the element count handed to the kernel.
    let a_host = [1.0f32, 2.0, 3.0];
    let b_host = [4.0f32, 5.0, 6.0];
    let n: u32 = 3;

    let buf_a = device.alloc_from(&a_host).expect("alloc a");
    let buf_b = device.alloc_from(&b_host).expect("alloc b");
    let mut buf_out = device.alloc_zeros::<f32>(n as usize).expect("alloc out");

    let cfg = LaunchConfig::for_num_elems(n);

    // SAFETY: the four arguments are pushed in the order the kernel declares
    // its parameters in `build_vector_add_module` (a_ptr, b_ptr, c_ptr, n),
    // and all three device buffers hold `n` f32 elements.
    match unsafe {
        device
            .stream()
            .launch_builder(func.inner())
            .arg(buf_a.inner())
            .arg(buf_b.inner())
            .arg(buf_out.inner_mut())
            .arg(&n)
            .launch(cfg)
    } {
        Ok(_) => {}
        Err(e) => {
            eprintln!("=== PTX ===\n{}", emit_ptx_debug(&ptx_module));
            panic!("kernel launch failed: {e}");
        }
    }

    let result = buf_out.to_host(&device).expect("to_host");
    let expected = vec![5.0f32, 7.0, 9.0];
    assert_eq!(result, expected, "vector_add produced wrong results");
}
/// Runs the `vector_add` kernel on 10,000 elements and compares the device
/// output with sums computed on the host.
///
/// Ignored by default: it needs device 0 to be available. On a load or
/// launch failure the generated PTX is printed to stderr before panicking.
#[test]
#[ignore]
fn vector_add_large() {
    let ptx_module = build_vector_add_module();
    let device = KaioDevice::new(0).expect("GPU required");

    let module = match device.load_module(&ptx_module) {
        Ok(loaded) => loaded,
        Err(e) => {
            eprintln!(
                "=== PTX that failed to load ===\n{}",
                emit_ptx_debug(&ptx_module)
            );
            panic!("load_module failed: {e}");
        }
    };
    let func = match module.function("vector_add") {
        Ok(found) => found,
        Err(e) => panic!("function('vector_add') failed: {e}"),
    };

    // Inputs: a[i] = i and b[i] = 2 * i; the reference sums are computed on
    // the host from those same vectors.
    let n: u32 = 10_000;
    let a_host: Vec<f32> = (0..n).map(|i| i as f32).collect();
    let b_host: Vec<f32> = (0..n).map(|i| (i * 2) as f32).collect();
    let expected: Vec<f32> = a_host
        .iter()
        .zip(b_host.iter())
        .map(|(x, y)| x + y)
        .collect();

    let buf_a = device.alloc_from(&a_host).expect("alloc a");
    let buf_b = device.alloc_from(&b_host).expect("alloc b");
    let mut buf_out = device.alloc_zeros::<f32>(n as usize).expect("alloc out");

    let cfg = LaunchConfig::for_num_elems(n);

    // SAFETY: the four arguments are pushed in the order the kernel declares
    // its parameters in `build_vector_add_module` (a_ptr, b_ptr, c_ptr, n),
    // and all three device buffers hold `n` f32 elements.
    match unsafe {
        device
            .stream()
            .launch_builder(func.inner())
            .arg(buf_a.inner())
            .arg(buf_b.inner())
            .arg(buf_out.inner_mut())
            .arg(&n)
            .launch(cfg)
    } {
        Ok(_) => {}
        Err(e) => {
            eprintln!("=== PTX ===\n{}", emit_ptx_debug(&ptx_module));
            panic!("kernel launch failed: {e}");
        }
    }

    let result = buf_out.to_host(&device).expect("to_host");
    assert_eq!(
        result, expected,
        "vector_add (10k elements) produced wrong results"
    );
}