use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
use super::Kernel;
/// Byte offset at which the hash table begins inside the kernel's shared
/// memory; the first `PAGE_SIZE` bytes precede it.
pub const PAGE_SIZE: u32 = 4 * 1024;

/// Number of bytes reserved for the hash table that follows the page region.
pub const HASH_TABLE_SIZE: u32 = 8 * 1024;

/// Total shared-memory size requested by the kernel: page region plus hash table.
pub const SMEM_SIZE: usize = PAGE_SIZE as usize + HASH_TABLE_SIZE as usize;
/// Stateless marker type for the `hash_store_test` PTX kernel.
///
/// The type carries no data; all behaviour lives in its `Kernel`
/// implementation. The common traits are derived so the value can be
/// logged, copied, and default-constructed without hand-written impls.
#[derive(Debug, Clone, Copy, Default)]
pub struct HashStoreTestKernel;

impl HashStoreTestKernel {
    /// Creates the kernel descriptor.
    ///
    /// Equivalent to [`Default::default`]; `const` so it can also be used in
    /// constant and static initialisers.
    pub const fn new() -> Self {
        Self
    }
}
impl Kernel for HashStoreTestKernel {
    /// Name under which the kernel is emitted.
    fn name(&self) -> &str {
        "hash_store_test"
    }

    /// Builds the kernel body.
    ///
    /// Only threads whose `tid.x & 31` is zero do any work; every other
    /// thread branches straight to the exit label. A participating thread:
    ///
    /// 1. stores `0xDEAD_0001` at `shared base + PAGE_SIZE` using an address
    ///    built from a single immediate offset;
    /// 2. stores `0xDEAD_0002` at `shared base + PAGE_SIZE + 100 * 4` using an
    ///    address built from a multiplied, widened index;
    /// 3. loads the value back from that second address and writes it to the
    ///    location given by the `output` parameter.
    ///
    /// NOTE(review): the two `PAGE_SIZE` immediates are materialised
    /// separately on purpose here so the emitted instruction sequence matches
    /// the original one call for call; presumably the kernel exists to compare
    /// the fixed-address and computed-address store paths — confirm before
    /// deduplicating.
    fn build_ptx(&self) -> PtxKernel {
        // Shared by the early-out branch and the label definition below.
        const EXIT_LABEL: &str = "L_done";

        PtxKernel::new(self.name())
            .param(PtxType::U64, "output")
            .shared_memory(SMEM_SIZE)
            .build(|ctx| {
                let out_ptr = ctx.load_param_u64("output");

                // lane = tid.x & 31; skip everything unless lane == 0.
                let tid_x = ctx.special_reg(PtxReg::TidX);
                let lane_mask = ctx.mov_u32_imm(31);
                let lane = ctx.and_u32(tid_x, lane_mask);
                let lane_zero = ctx.mov_u32_imm(0);
                let leader_pred = ctx.setp_eq_u32(lane, lane_zero);
                ctx.branch_if_not(leader_pred, EXIT_LABEL);

                // Store #1: address formed from base + one immediate offset.
                let shared_base = ctx.shared_base_addr();
                let direct_offset = ctx.mov_u64_imm(u64::from(PAGE_SIZE));
                let direct_addr = ctx.add_u64(shared_base, direct_offset);
                let marker_fixed = ctx.mov_u32_imm(0xDEAD_0001);
                ctx.st_generic_u32(direct_addr, marker_fixed);

                // Store #2: entry 100 of the table, 4 bytes per entry, with the
                // 32-bit byte offset widened to 64 bits before the address add.
                let entry_index = ctx.mov_u32_imm(100);
                let entry_bytes = ctx.mul_u32(entry_index, 4);
                let entry_bytes_wide = ctx.cvt_u64_u32(entry_bytes);
                let table_offset = ctx.mov_u64_imm(u64::from(PAGE_SIZE));
                let table_base = ctx.add_u64(shared_base, table_offset);
                let entry_addr = ctx.add_u64(table_base, entry_bytes_wide);
                let marker_hashed = ctx.mov_u32_imm(0xDEAD_0002);
                ctx.st_generic_u32(entry_addr, marker_hashed);

                // Read the entry back and publish it through `output`.
                let readback = ctx.ld_generic_u32(entry_addr);
                ctx.st_global_u32(out_ptr, readback);

                ctx.label(EXIT_LABEL);
                ctx.ret();
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Width in bytes of one hash-table entry; matches the `* 4` index
    /// scaling performed in `build_ptx`.
    const ENTRY_BYTES: u32 = 4;

    /// Largest entry index that still lies inside the hash-table region.
    const MAX_HASH_IDX: u32 = HASH_TABLE_SIZE / ENTRY_BYTES - 1;

    /// Substrings that select the PTX lines worth printing for diagnosis.
    ///
    /// The conversion pattern is `cvta.shared` (shared → generic), the same
    /// instruction the assertion below requires. The previous filter looked
    /// for `cvta.to.shared`, which is the opposite conversion and therefore
    /// never matched the instruction this kernel is expected to emit.
    const RELEVANT_PATTERNS: [&str; 5] = [
        "st.u32",
        "mul.lo",
        "add.u64",
        "cvt.u64",
        "cvta.shared",
    ];

    /// The emitted PTX must contain the entry point, a shared-memory
    /// declaration, and each instruction family the kernel relies on.
    #[test]
    fn test_hash_store_kernel_generates_valid_ptx() {
        let kernel = HashStoreTestKernel::new();
        let ptx = kernel.emit_ptx();

        assert!(ptx.contains(".entry hash_store_test"), "Missing entry point");
        assert!(ptx.contains(".shared"), "Missing shared memory");
        assert!(
            ptx.contains("cvta.shared"),
            "Missing cvta.shared for shared→generic conversion"
        );
        assert!(ptx.contains("st.u32"), "Missing st.u32 instructions");
        assert!(ptx.contains("mul.lo"), "Missing mul.lo for offset");
        assert!(ptx.contains("cvt.u64.u32"), "Missing cvt.u64.u32");

        // Diagnostic dump (visible with `--nocapture`) of the address
        // arithmetic and store instructions, prefixed by their line index.
        println!("=== HashStoreTestKernel PTX (relevant lines) ===");
        for (i, line) in ptx.lines().enumerate() {
            if RELEVANT_PATTERNS.iter().any(|pat| line.contains(pat)) {
                println!("{:4}: {}", i, line);
            }
        }
    }

    /// The last byte of the last hash-table entry must fall inside the
    /// shared-memory allocation.
    #[test]
    fn test_shared_mem_size_sufficient() {
        assert!(SMEM_SIZE >= (PAGE_SIZE + HASH_TABLE_SIZE) as usize);

        let max_entry_offset: u32 = MAX_HASH_IDX * ENTRY_BYTES;
        // A store of `ENTRY_BYTES` bytes touches offsets up to
        // `start + ENTRY_BYTES - 1`.
        let max_byte_accessed = PAGE_SIZE + max_entry_offset + (ENTRY_BYTES - 1);
        assert!(
            (max_byte_accessed as usize) < SMEM_SIZE,
            "Max byte {} >= smem size {}",
            max_byte_accessed,
            SMEM_SIZE
        );
    }

    /// The 32-bit offset arithmetic must not wrap for the largest entry
    /// index. The literal expectations pin the current constant values.
    #[test]
    fn test_offset_computation_no_overflow() {
        let max_hash_idx: u32 = MAX_HASH_IDX;
        assert_eq!(max_hash_idx, 2047);

        let max_entry_off = max_hash_idx
            .checked_mul(ENTRY_BYTES)
            .expect("should not overflow");
        assert_eq!(max_entry_off, 8188);

        let total_offset = PAGE_SIZE
            .checked_add(max_entry_off)
            .expect("should not overflow");
        assert_eq!(total_offset, 12284);
    }
}