use super::super::common::*;
/// Builds and runs only the setup phase of the per-page compression kernel and
/// checks that it completes and reports every progress marker in order.
///
/// The setup phase has five steps:
/// - a bounds check;
/// - a copy of the input page into shared memory;
/// - hash-table initialisation;
/// - state initialisation;
/// - the `output_sizes` write.
///
/// Per-warp shared-memory layout, with offsets relative to `warp_id * WARP_SMEM_SIZE`:
///
/// - `[0, PAGE_SIZE_VAL)`: copy of the input page
/// - `[HASH_TABLE_OFF, HASH_TABLE_OFF + HASH_TABLE_SIZE)`: hash table, filled with `0xFFFFFFFF`
/// - `[STATE_OFF, STATE_OFF + 12)`: three state words, zeroed
/// - `[LOOP_CTR_OFF, LOOP_CTR_OFF + 4)`: counter used by the two init loops
///
/// Returns early (skips) when no CUDA device is available.
#[test]
fn fkr_101_setup_only_test() {
    if !cuda_available() {
        eprintln!("SKIPPED: No CUDA");
        return;
    }
    let ctx = CudaContext::new(0).expect("CUDA context");
    let stream = CudaStream::new(&ctx).expect("CUDA stream");

    // Kernel geometry: warp `w` of block `b` handles page `b * WARPS_PER_BLOCK + w`
    // and owns its own WARP_SMEM_SIZE-byte slice of shared memory.
    const WARPS_PER_BLOCK: u32 = 3;
    const WARP_SMEM_SIZE: u32 = 12544;
    const SMEM_SIZE: usize = (WARP_SMEM_SIZE * WARPS_PER_BLOCK) as usize;
    const PAGE_SIZE_VAL: u32 = 4096;
    const OUTPUT_PAGE_SIZE: u32 = 4352;
    const HASH_TABLE_SIZE: u32 = 8192;
    const HASH_TABLE_OFF: u32 = PAGE_SIZE_VAL;
    const STATE_OFF: u32 = HASH_TABLE_OFF + HASH_TABLE_SIZE + 128 + 4;
    // The loop counter gets a dedicated slot after the three state words so it never
    // aliases the page copy that starts at offset 0 of the warp's slice.
    const LOOP_CTR_OFF: u32 = STATE_OFF + 12;
    const _: () = assert!(HASH_TABLE_OFF + HASH_TABLE_SIZE <= STATE_OFF);
    const _: () = assert!(LOOP_CTR_OFF + 4 <= WARP_SMEM_SIZE);

    // Progress markers. The host reads debug_buf[0] as the number of markers
    // recorded and debug_buf[1..] as the markers themselves, in emission order.
    const MK_BOUNDS_OK: u32 = 0xAA00_0000;
    const MK_WARP_OFF: u32 = 0xAA00_0001;
    const MK_LOAD_DONE: u32 = 0xAA00_0002;
    const MK_HASH_DONE: u32 = 0xAA00_0003;
    const MK_ALL_DONE: u32 = 0xAA00_0004;
    const MARKERS: [(u32, &str); 5] = [
        (MK_BOUNDS_OK, "BOUNDS_OK"),
        (MK_WARP_OFF, "WARP_OFF"),
        (MK_LOAD_DONE, "LOAD_DONE"),
        (MK_HASH_DONE, "HASH_DONE"),
        (MK_ALL_DONE, "ALL_DONE"),
    ];

    let kernel = PtxKernel::new("setup_only")
        .param(PtxType::U64, "input_batch")
        .param(PtxType::U64, "output_batch")
        .param(PtxType::U64, "output_sizes")
        .param(PtxType::U64, "debug_buf")
        .param(PtxType::U32, "batch_size")
        .shared_memory(SMEM_SIZE)
        .build(|ctx| {
            let input_ptr = ctx.load_param_u64("input_batch");
            let output_ptr = ctx.load_param_u64("output_batch");
            let sizes_ptr = ctx.load_param_u64("output_sizes");
            let debug_ptr = ctx.load_param_u64("debug_buf");
            let batch_size = ctx.load_param_u32("batch_size");

            // Shared immediates, defined before any branch so every later use is
            // reached only after its definition.
            let zero = ctx.mov_u32_imm(0);
            let four = ctx.mov_u32_imm(4);
            let eight = ctx.mov_u32_imm(8);

            // Only lane 0 of each warp does any work.
            let tid = ctx.special_reg(PtxReg::TidX);
            let bid = ctx.special_reg(PtxReg::CtaIdX);
            let warp_id = ctx.shr_u32_imm(tid, 5);
            let lane_mask = ctx.mov_u32_imm(31);
            let lane_id = ctx.and_u32(tid, lane_mask);
            let is_leader = ctx.setp_eq_u32(lane_id, zero);
            ctx.branch_if_not(is_leader, "L_end");

            // Page index for this warp; warps past the end of the batch exit.
            let warps_per_block = ctx.mov_u32_imm(WARPS_PER_BLOCK);
            let block_offset = ctx.mul_lo_u32(bid, warps_per_block);
            let page_idx = ctx.add_u32_reg(block_offset, warp_id);
            let out_of_bounds = ctx.setp_ge_u32(page_idx, batch_size);
            ctx.branch_if(out_of_bounds, "L_end");
            ctx.emit_debug_marker(debug_ptr, MK_BOUNDS_OK);

            let warp_smem_size = ctx.mov_u32_imm(WARP_SMEM_SIZE);
            let warp_off = ctx.mul_lo_u32(warp_id, warp_smem_size);
            let loop_ctr_off_val = ctx.mov_u32_imm(LOOP_CTR_OFF);
            let loop_ctr_off = ctx.add_u32_reg(warp_off, loop_ctr_off_val);
            ctx.emit_debug_marker(debug_ptr, MK_WARP_OFF);

            // Copy the page from global memory into [warp_off, warp_off + PAGE_SIZE_VAL),
            // one u32 per iteration. The loop counter is kept in shared memory.
            let page_size_val = ctx.mov_u32_imm(PAGE_SIZE_VAL);
            let page_offset = ctx.mul_lo_u32(page_idx, page_size_val);
            let page_offset_64 = ctx.cvt_u64_u32(page_offset);
            let input_page_ptr = ctx.add_u64(input_ptr, page_offset_64);
            ctx.st_shared_u32(loop_ctr_off, zero);
            ctx.label("L_load_loop");
            let idx = ctx.ld_shared_u32(loop_ctr_off);
            let load_done = ctx.setp_ge_u32(idx, page_size_val);
            ctx.branch_if(load_done, "L_load_done");
            let idx_64 = ctx.cvt_u64_u32(idx);
            let src_addr = ctx.add_u64(input_page_ptr, idx_64);
            let val = ctx.ld_global_u32(src_addr);
            let dst_off = ctx.add_u32_reg(warp_off, idx);
            ctx.st_shared_u32(dst_off, val);
            let idx_next = ctx.add_u32_reg(idx, four);
            ctx.st_shared_u32(loop_ctr_off, idx_next);
            ctx.branch("L_load_loop");
            ctx.label("L_load_done");
            ctx.emit_debug_marker(debug_ptr, MK_LOAD_DONE);

            // Fill the hash table with 0xFFFFFFFF, reusing the same counter slot.
            let hash_base_off_val = ctx.mov_u32_imm(HASH_TABLE_OFF);
            let hash_base_off = ctx.add_u32_reg(warp_off, hash_base_off_val);
            let invalid_marker = ctx.mov_u32_imm(0xFFFF_FFFF);
            let hash_table_size = ctx.mov_u32_imm(HASH_TABLE_SIZE);
            ctx.st_shared_u32(loop_ctr_off, zero);
            ctx.label("L_hash_init");
            let h_idx = ctx.ld_shared_u32(loop_ctr_off);
            let init_done = ctx.setp_ge_u32(h_idx, hash_table_size);
            ctx.branch_if(init_done, "L_hash_done");
            let hash_off = ctx.add_u32_reg(hash_base_off, h_idx);
            ctx.st_shared_u32(hash_off, invalid_marker);
            let h_next = ctx.add_u32_reg(h_idx, four);
            ctx.st_shared_u32(loop_ctr_off, h_next);
            ctx.branch("L_hash_init");
            ctx.label("L_hash_done");
            ctx.emit_debug_marker(debug_ptr, MK_HASH_DONE);

            // Zero the three state words at STATE_OFF, STATE_OFF + 4 and STATE_OFF + 8.
            let state_off_val = ctx.mov_u32_imm(STATE_OFF);
            let state_off = ctx.add_u32_reg(warp_off, state_off_val);
            let out_pos_state_off = ctx.add_u32_reg(state_off, four);
            let anchor_state_off = ctx.add_u32_reg(state_off, eight);
            ctx.st_shared_u32(state_off, zero);
            ctx.st_shared_u32(out_pos_state_off, zero);
            ctx.st_shared_u32(anchor_state_off, zero);

            let output_size = ctx.mov_u32_imm(OUTPUT_PAGE_SIZE);
            let output_offset = ctx.mul_lo_u32(page_idx, output_size);
            let output_offset_64 = ctx.cvt_u64_u32(output_offset);
            let _output_page_ptr = ctx.add_u64(output_ptr, output_offset_64);

            // Not consumed by this setup-only kernel. They are still emitted so the
            // register setup matches the full compressor's prologue.
            let _limit = ctx.mov_u32_imm(PAGE_SIZE_VAL - 12);
            let _lz4_prime = ctx.mov_u32_imm(0x9E37_79B1);
            let _hash_shift = ctx.mov_u32_imm(21);
            // One entry per u32 slot of the hash table: 8192 / 4 - 1 = 0x7FF.
            let _hash_mask = ctx.mov_u32_imm(HASH_TABLE_SIZE / 4 - 1);

            // output_sizes[page_idx] = 0 (nothing is compressed in this test).
            let final_pos = ctx.mov_u32_imm(0);
            let page_idx_64 = ctx.cvt_u64_u32(page_idx);
            let four_64 = ctx.mov_u64_imm(4);
            let size_offset = ctx.mul_u64_reg(page_idx_64, four_64);
            let size_addr = ctx.add_u64(sizes_ptr, size_offset);
            ctx.st_global_u32(size_addr, final_pos);
            ctx.emit_debug_marker(debug_ptr, MK_ALL_DONE);

            ctx.label("L_end");
            ctx.ret();
        });
    let ptx = PtxModule::new()
        .version(8, 0)
        .target("sm_89")
        .address_size(64)
        .add_kernel(kernel)
        .emit();
    println!("=== Setup-Only PTX ===");
    for (i, line) in ptx.lines().take(50).enumerate() {
        println!("{:4}: {}", i + 1, line);
    }

    // Buffer sizes match PAGE_SIZE_VAL (input) and OUTPUT_PAGE_SIZE (output) for one page.
    let mut input_buf: GpuBuffer<u8> = GpuBuffer::new(&ctx, 4096).unwrap();
    input_buf
        .copy_from_host(&(0..PAGE_SIZE_VAL).map(|i| (i % 256) as u8).collect::<Vec<_>>())
        .unwrap();
    let mut output_buf: GpuBuffer<u8> = GpuBuffer::new(&ctx, 4352).unwrap();
    let mut sizes_buf: GpuBuffer<u32> = GpuBuffer::new(&ctx, 1).unwrap();
    // Sentinel, so the final check proves the kernel actually wrote output_sizes[0].
    sizes_buf.copy_from_host(&vec![u32::MAX; 1]).unwrap();
    let mut debug_buf: GpuBuffer<u32> = GpuBuffer::new(&ctx, 64).unwrap();
    debug_buf.copy_from_host(&vec![0u32; 64]).unwrap();
    let mut module = CudaModule::from_ptx(&ctx, &ptx).expect("PTX load");
    // 96 threads = WARPS_PER_BLOCK warps of 32. Shared memory is requested by the
    // kernel itself via `.shared_memory(SMEM_SIZE)`, so no dynamic bytes are passed here
    // (presumably a static allocation — TODO confirm against the builder).
    let config = LaunchConfig {
        grid: (1, 1, 1),
        block: (96, 1, 1),
        shared_mem: 0,
    };
    let batch_size: u32 = 1;
    let mut args: [*mut c_void; 5] = [
        input_buf.as_kernel_arg(),
        output_buf.as_kernel_arg(),
        sizes_buf.as_kernel_arg(),
        debug_buf.as_kernel_arg(),
        &batch_size as *const u32 as *mut c_void,
    ];
    println!("Launching setup-only kernel...");
    // SAFETY: `args` has one entry per kernel parameter, in declaration order. Every
    // pointee (the four buffers and `batch_size`) stays alive until after
    // `stream.synchronize()` below.
    unsafe {
        stream
            .launch_kernel(&mut module, "setup_only", &config, &mut args)
            .expect("Launch");
    }
    let sync_result = stream.synchronize();

    // Read the markers before acting on the sync result, so a crash still shows how far
    // the kernel got. After a fault the readback itself may fail; in that case, report
    // it without masking the launch error.
    let mut output = vec![0u32; 64];
    let readback = debug_buf.copy_to_host(&mut output);
    if let Err(e) = &readback {
        eprintln!("Debug buffer readback failed: {:?}", e);
    }
    println!("Counter: {}", output[0]);
    let shown = (output[0] as usize).min(10);
    for (i, &m) in output[1..].iter().take(shown).enumerate() {
        let name = MARKERS
            .iter()
            .find(|&&(value, _)| value == m)
            .map_or("UNKNOWN", |&(_, name)| name);
        println!(" [{:2}] 0x{:08X} ({})", i, m, name);
    }
    if let Err(e) = sync_result {
        panic!("Setup-only test crashed: {:?}", e);
    }
    readback.expect("debug buffer readback");

    assert_eq!(output[0], 5, "Should have 5 markers");
    let expected: Vec<u32> = MARKERS.iter().map(|&(value, _)| value).collect();
    assert_eq!(
        &output[1..=expected.len()],
        expected.as_slice(),
        "Markers missing or out of order"
    );

    let mut sizes = vec![0u32; 1];
    sizes_buf.copy_to_host(&mut sizes).unwrap();
    assert_eq!(sizes[0], 0, "Kernel should write final_pos = 0 to output_sizes[0]");
    println!("Setup-only test PASSED!");
}