pub mod builder;
mod emit;
mod instructions;
pub mod optimize;
pub mod parity;
mod registers;
mod types;
pub use builder::{
KernelBuilder, KernelBuilderCore, PtxArithmetic, PtxAtomic, PtxComparison, PtxControl,
PtxKernel, PtxMemory, PtxModule, PtxSync,
};
pub use instructions::{CmpOp, Operand, PtxInstruction, PtxOp, WmmaLayout, WmmaShape};
pub use registers::{LiveRange, PhysicalReg, PtxReg, RegisterAllocator, VirtualReg};
pub use types::{PtxStateSpace, PtxType};
use crate::error::{GpuError, Result};
/// Oldest PTX ISA version, as `(major, minor)`, accepted by [`validate_version`].
pub const MIN_PTX_VERSION: (u32, u32) = (7, 0);

/// Checks that a requested PTX ISA version is not older than [`MIN_PTX_VERSION`].
///
/// The comparison is lexicographic on `(major, minor)`: the major number
/// decides first, and the minor number only matters when the majors are equal.
///
/// # Errors
///
/// Returns [`GpuError::InvalidPtxVersion`] carrying the rejected `major` and
/// `minor` when the version is below [`MIN_PTX_VERSION`].
pub fn validate_version(major: u32, minor: u32) -> Result<()> {
    let requested = (major, minor);
    if requested >= MIN_PTX_VERSION {
        Ok(())
    } else {
        Err(GpuError::InvalidPtxVersion { major, minor })
    }
}
/// Checks that `target` names a supported SM architecture of the form `sm_<N>`.
///
/// `<N>` must be a canonical decimal number: ASCII digits only, no sign, and
/// no leading zero. Its value must be at least 70. Examples of accepted
/// strings: `"sm_70"`, `"sm_86"`, `"sm_100"`.
///
/// # Errors
///
/// Returns [`GpuError::InvalidTarget`] holding the original string when:
/// - the `sm_` prefix is missing,
/// - the suffix is empty, contains a non-digit (including `+`/`-`), or has a
///   leading zero,
/// - the number does not fit in a `u32`, or
/// - the number is below 70.
pub fn validate_target(target: &str) -> Result<()> {
    /// Lowest SM architecture number accepted.
    const MIN_SM_VERSION: u32 = 70;

    let invalid = || GpuError::InvalidTarget(target.to_string());

    let digits = target.strip_prefix("sm_").ok_or_else(invalid)?;

    // `u32::from_str` alone is too lenient: it accepts a leading `+`
    // ("sm_+80") and leading zeros ("sm_080"). Only canonical decimal
    // spellings are allowed through.
    let has_leading_zero = digits.len() > 1 && digits.starts_with('0');
    if digits.is_empty() || has_leading_zero || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return Err(invalid());
    }

    // Parsing can still fail on overflow for absurdly long digit strings.
    let version: u32 = digits.parse().map_err(|_| invalid())?;
    if version < MIN_SM_VERSION {
        return Err(invalid());
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PTX version / target validation --------------------------------

    #[test]
    fn test_validate_version_valid() {
        assert!(validate_version(7, 0).is_ok());
        assert!(validate_version(7, 9).is_ok());
        assert!(validate_version(8, 0).is_ok());
        assert!(validate_version(8, 5).is_ok());
        assert!(validate_version(9, 0).is_ok());
        // The minimum itself must always be accepted.
        assert!(validate_version(MIN_PTX_VERSION.0, MIN_PTX_VERSION.1).is_ok());
    }

    #[test]
    fn test_validate_version_invalid() {
        assert!(validate_version(6, 9).is_err());
        assert!(validate_version(6, 5).is_err());
        assert!(validate_version(6, 0).is_err());
        assert!(validate_version(5, 0).is_err());
        assert!(validate_version(0, 0).is_err());
        // A large minor number must not compensate for an older major version.
        assert!(validate_version(6, u32::MAX).is_err());
    }

    #[test]
    fn test_validate_version_error_carries_version() {
        match validate_version(6, 5) {
            Err(GpuError::InvalidPtxVersion { major, minor }) => {
                assert_eq!((major, minor), (6, 5))
            }
            _ => panic!("expected GpuError::InvalidPtxVersion"),
        }
    }

    #[test]
    fn test_validate_target_valid() {
        assert!(validate_target("sm_70").is_ok());
        assert!(validate_target("sm_75").is_ok());
        assert!(validate_target("sm_80").is_ok());
        assert!(validate_target("sm_86").is_ok());
        assert!(validate_target("sm_89").is_ok());
        assert!(validate_target("sm_90").is_ok());
        assert!(validate_target("sm_100").is_ok());
    }

    #[test]
    fn test_validate_target_invalid() {
        assert!(validate_target("sm_50").is_err());
        assert!(validate_target("sm_60").is_err());
        assert!(validate_target("sm_61").is_err());
        // Just below the accepted boundary.
        assert!(validate_target("sm_69").is_err());
        assert!(validate_target("compute_70").is_err());
        assert!(validate_target("70").is_err());
        assert!(validate_target("").is_err());
        // Prefix present but no usable number after it.
        assert!(validate_target("sm_").is_err());
        assert!(validate_target("sm_abc").is_err());
        // The prefix check is case-sensitive.
        assert!(validate_target("SM_70").is_err());
    }

    #[test]
    fn test_validate_target_error_carries_target() {
        match validate_target("sm_60") {
            Err(GpuError::InvalidTarget(t)) => assert_eq!(t, "sm_60"),
            _ => panic!("expected GpuError::InvalidTarget"),
        }
    }

    // ---- module / kernel builders ----------------------------------------

    #[test]
    fn test_module_creation() {
        let module = PtxModule::new();
        // Version reported by a freshly constructed module.
        assert_eq!(module.get_version(), (8, 0));
    }

    #[test]
    fn test_module_version_builder() {
        let module = PtxModule::new().version(8, 5);
        assert_eq!(module.get_version(), (8, 5));
    }

    #[test]
    fn test_module_target_builder() {
        let module = PtxModule::new().target("sm_86");
        assert_eq!(module.get_target(), "sm_86");
    }

    #[test]
    fn test_module_address_size() {
        let module = PtxModule::new().address_size(64);
        assert_eq!(module.get_address_size(), 64);
    }

    #[test]
    fn test_module_emit_header() {
        let module = PtxModule::new()
            .version(8, 0)
            .target("sm_70")
            .address_size(64);
        let ptx = module.emit();
        assert!(ptx.contains(".version 8.0"));
        assert!(ptx.contains(".target sm_70"));
        assert!(ptx.contains(".address_size 64"));
    }

    #[test]
    fn test_module_emit_with_kernel() {
        let kernel = PtxKernel::new("vector_add")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "b_ptr")
            .param(PtxType::U64, "c_ptr")
            .param(PtxType::U32, "n");
        let module = PtxModule::new()
            .version(8, 0)
            .target("sm_70")
            .address_size(64)
            .add_kernel(kernel);
        let ptx = module.emit();
        assert!(ptx.contains(".visible .entry vector_add"));
        assert!(ptx.contains(".param .u64 a_ptr"));
        assert!(ptx.contains(".param .u64 b_ptr"));
        assert!(ptx.contains(".param .u64 c_ptr"));
        assert!(ptx.contains(".param .u32 n"));
    }

    #[test]
    fn test_kernel_with_shared_memory() {
        let kernel = PtxKernel::new("gemm_tiled").shared_memory(4096);
        assert_eq!(kernel.shared_memory_bytes(), 4096);
    }

    // ---- types and registers ---------------------------------------------

    #[test]
    fn test_ptx_type_sizes() {
        assert_eq!(PtxType::U8.size_bytes(), 1);
        assert_eq!(PtxType::U16.size_bytes(), 2);
        assert_eq!(PtxType::U32.size_bytes(), 4);
        assert_eq!(PtxType::U64.size_bytes(), 8);
        assert_eq!(PtxType::F16.size_bytes(), 2);
        assert_eq!(PtxType::F32.size_bytes(), 4);
        assert_eq!(PtxType::F64.size_bytes(), 8);
    }

    #[test]
    fn test_ptx_type_to_string() {
        assert_eq!(PtxType::U32.to_ptx_string(), ".u32");
        assert_eq!(PtxType::U64.to_ptx_string(), ".u64");
        assert_eq!(PtxType::F32.to_ptx_string(), ".f32");
        assert_eq!(PtxType::F16.to_ptx_string(), ".f16");
        assert_eq!(PtxType::Pred.to_ptx_string(), ".pred");
    }

    #[test]
    fn test_special_registers() {
        assert_eq!(PtxReg::TidX.to_ptx_string(), "%tid.x");
        assert_eq!(PtxReg::TidY.to_ptx_string(), "%tid.y");
        assert_eq!(PtxReg::TidZ.to_ptx_string(), "%tid.z");
        assert_eq!(PtxReg::CtaIdX.to_ptx_string(), "%ctaid.x");
        assert_eq!(PtxReg::NtidX.to_ptx_string(), "%ntid.x");
    }

    #[test]
    fn test_virtual_register_allocation() {
        let mut allocator = RegisterAllocator::new();
        let r1 = allocator.allocate_virtual(PtxType::F32);
        let r2 = allocator.allocate_virtual(PtxType::F32);
        let r3 = allocator.allocate_virtual(PtxType::U32);
        assert_ne!(r1.id(), r2.id());
        assert_ne!(r2.id(), r3.id());
    }

    #[test]
    fn test_register_pressure_tracking() {
        let mut allocator = RegisterAllocator::new();
        let _r1 = allocator.allocate_virtual(PtxType::F32);
        let _r2 = allocator.allocate_virtual(PtxType::F32);
        let _r3 = allocator.allocate_virtual(PtxType::F32);
        let pressure = allocator.pressure_report();
        assert_eq!(pressure.max_live, 3);
        assert_eq!(pressure.spill_count, 0);
    }

    // ---- end-to-end kernel emission --------------------------------------

    #[test]
    fn test_emit_vector_add_kernel() {
        let kernel = PtxKernel::new("vector_add")
            .param(PtxType::U64, "a_ptr")
            .param(PtxType::U64, "b_ptr")
            .param(PtxType::U64, "c_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let idx = ctx.mad_lo_u32(ctaid, ntid, tid);
                let n = ctx.load_param_u32("n");
                // Out-of-range indices skip straight to the `exit` label.
                let pred = ctx.setp_ge_u32(idx, n);
                ctx.branch_if(pred, "exit");
                let a_ptr = ctx.load_param_u64("a_ptr");
                let b_ptr = ctx.load_param_u64("b_ptr");
                let c_ptr = ctx.load_param_u64("c_ptr");
                // 4 == PtxType::F32.size_bytes() (see test_ptx_type_sizes).
                let offset = ctx.mul_wide_u32(idx, 4);
                let a_addr = ctx.add_u64(a_ptr, offset);
                let b_addr = ctx.add_u64(b_ptr, offset);
                let c_addr = ctx.add_u64(c_ptr, offset);
                let a_val = ctx.ld_global_f32(a_addr);
                let b_val = ctx.ld_global_f32(b_addr);
                let c_val = ctx.add_f32(a_val, b_val);
                ctx.st_global_f32(c_addr, c_val);
                ctx.label("exit");
                ctx.ret();
            });
        let module = PtxModule::new()
            .version(8, 0)
            .target("sm_70")
            .address_size(64)
            .add_kernel(kernel);
        let ptx = module.emit();
        assert!(ptx.contains("mov.u32"));
        assert!(ptx.contains("%tid.x"));
        assert!(ptx.contains("mad.lo"));
        assert!(ptx.contains("setp.ge"));
        assert!(ptx.contains("ld.global.f32"));
        assert!(ptx.contains("add.f32"));
        assert!(ptx.contains("st.global.f32"));
        assert!(ptx.contains("ret;"));
    }
}