#![allow(clippy::similar_names)]
use super::{
LZ4_HASH_BITS, LZ4_HASH_MULT, LZ4_HASH_SIZE, LZ4_MAX_OFFSET, LZ4_MIN_MATCH, PAGE_SIZE,
};
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Warp-cooperative LZ4 page-compression kernel generator.
///
/// One warp (32 lanes) handles one 4096-byte page; a 128-thread block covers
/// four pages. `emit_wgsl` produces the WebGPU/WGSL variant and the `Kernel`
/// impl (`build_ptx`) produces the CUDA PTX variant.
#[derive(Debug, Clone)]
pub struct Lz4WarpCompressKernel {
    /// Number of pages in the batch; embedded into the generated kernels for
    /// bounds checking and used to size the dispatch grid.
    batch_size: u32,
}
impl Lz4WarpCompressKernel {
    /// Creates a generator for a batch of `batch_size` 4096-byte pages.
    #[must_use]
    pub fn new(batch_size: u32) -> Self {
        Self { batch_size }
    }

    /// Number of pages this kernel instance is configured to process.
    #[must_use]
    pub fn batch_size(&self) -> u32 {
        self.batch_size
    }

    /// Grid dimensions: each block handles 4 pages (one warp per page), so
    /// the grid is `ceil(batch_size / 4)` blocks on the x axis.
    #[must_use]
    pub fn grid_dim(&self) -> (u32, u32, u32) {
        let pages_per_block = 4;
        // div_ceil replaces the manual `(a + b - 1) / b` rounding idiom
        // (clippy::manual_div_ceil) and cannot overflow near u32::MAX.
        let num_blocks = self.batch_size.div_ceil(pages_per_block);
        (num_blocks, 1, 1)
    }

    /// Block dimensions: 128 threads = 4 warps of 32 lanes.
    #[must_use]
    pub fn block_dim(&self) -> (u32, u32, u32) {
        (128, 1, 1)
    }

    /// Shared memory per block: for each of the 4 warps, one page of data
    /// plus a 16-bit hash table (`LZ4_HASH_SIZE` entries * 2 bytes).
    #[must_use]
    pub fn shared_memory_bytes(&self) -> usize {
        4 * (PAGE_SIZE as usize + LZ4_HASH_SIZE as usize * 2)
    }

    /// Emits the WGSL variant of the kernel.
    ///
    /// The WGSL path performs a cooperative page load, zero-page detection
    /// via a per-lane OR reduction, and a cooperative copy-through store.
    /// Reported size is 20 bytes for an all-zero page and `PAGE_SIZE`
    /// (uncompressed) otherwise.
    ///
    /// Fix vs. the previous revision: the reduction scratch now lives after
    /// ALL four pages' data (u32 indices 4096..4224). The old index
    /// `smem_warp_base + 1024u + lane_id` placed warp N's scratch exactly on
    /// warp N+1's page data (warp 0's scratch was indices 1024..1055, warp
    /// 1's page), so the Phase-3 copy-out emitted corrupted pages.
    #[must_use]
    pub fn emit_wgsl(&self) -> String {
        format!(
            r"// LZ4 Warp-Cooperative Compression Kernel (WGSL)
// Generated by trueno-gpu - Pure Rust GPU code generation
// WebGPU cross-platform: Vulkan, Metal, DX12, WebGPU
// Constants
const PAGE_SIZE: u32 = 4096u;
const SUBGROUP_SIZE: u32 = 32u;
const PAGES_PER_WORKGROUP: u32 = 4u;
// Bindings
@group(0) @binding(0) var<storage, read> input_batch: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_batch: array<u32>;
@group(0) @binding(2) var<storage, read_write> output_sizes: array<u32>;
// Workgroup shared memory (48KB per workgroup)
var<workgroup> smem: array<u32, 12288>; // 48KB / 4 bytes
@compute @workgroup_size(128, 1, 1)
fn lz4_compress_warp(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {{
let batch_size: u32 = {batch_size}u;
// Calculate warp and lane IDs (WGSL uses subgroups)
let thread_id = local_id.x;
let warp_id = thread_id / SUBGROUP_SIZE;
let lane_id = thread_id % SUBGROUP_SIZE;
// Calculate page assignment
let page_id = workgroup_id.x * PAGES_PER_WORKGROUP + warp_id;
// Bounds check
if (page_id >= batch_size) {{
return;
}}
// Calculate memory offsets
let page_offset = page_id * (PAGE_SIZE / 4u); // In u32 units
let load_base = lane_id * 32u; // 128 bytes per thread = 32 u32s
let smem_warp_base = warp_id * (PAGE_SIZE / 4u);
// Phase 1: Cooperative load from global to shared memory
for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
let chunk_off = load_base + i * 8u; // 8 u32s per iteration
let global_idx = page_offset + chunk_off;
let smem_idx = smem_warp_base + chunk_off;
// Load 8 u32s (32 bytes)
smem[smem_idx + 0u] = input_batch[global_idx + 0u];
smem[smem_idx + 1u] = input_batch[global_idx + 1u];
smem[smem_idx + 2u] = input_batch[global_idx + 2u];
smem[smem_idx + 3u] = input_batch[global_idx + 3u];
smem[smem_idx + 4u] = input_batch[global_idx + 4u];
smem[smem_idx + 5u] = input_batch[global_idx + 5u];
smem[smem_idx + 6u] = input_batch[global_idx + 6u];
smem[smem_idx + 7u] = input_batch[global_idx + 7u];
}}
// Workgroup barrier
workgroupBarrier();
// Phase 2: Zero-page detection with parallel reduction
// Each thread checks if its 128 bytes are all zeros
var thread_or: u32 = 0u;
for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
let chunk_off = load_base + i * 8u;
let smem_idx = smem_warp_base + chunk_off;
thread_or = thread_or | smem[smem_idx + 0u];
thread_or = thread_or | smem[smem_idx + 1u];
thread_or = thread_or | smem[smem_idx + 2u];
thread_or = thread_or | smem[smem_idx + 3u];
thread_or = thread_or | smem[smem_idx + 4u];
thread_or = thread_or | smem[smem_idx + 5u];
thread_or = thread_or | smem[smem_idx + 6u];
thread_or = thread_or | smem[smem_idx + 7u];
}}
// Store each thread's result for reduction.
// Scratch lives after ALL page data (4 pages * 1024 u32s => indices
// 4096..4224). Indexing at smem_warp_base + 1024u would overlap the
// next warp's page data and corrupt its copied-out page.
let reduction_idx = PAGES_PER_WORKGROUP * (PAGE_SIZE / 4u) + warp_id * SUBGROUP_SIZE + lane_id;
smem[reduction_idx] = thread_or;
workgroupBarrier();
// Lane 0 reduces all 32 values and writes output size
if (lane_id == 0u) {{
var page_or: u32 = 0u;
for (var j: u32 = 0u; j < SUBGROUP_SIZE; j = j + 1u) {{
page_or = page_or | smem[PAGES_PER_WORKGROUP * (PAGE_SIZE / 4u) + warp_id * SUBGROUP_SIZE + j];
}}
// Zero page: compressed to 20 bytes, non-zero: uncompressed
if (page_or == 0u) {{
output_sizes[page_id] = 20u; // LZ4 minimal encoding
}} else {{
output_sizes[page_id] = PAGE_SIZE;
}}
}}
workgroupBarrier();
// Phase 3: Cooperative store from shared to global memory
for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
let chunk_off = load_base + i * 8u;
let global_idx = page_offset + chunk_off;
let smem_idx = smem_warp_base + chunk_off;
output_batch[global_idx + 0u] = smem[smem_idx + 0u];
output_batch[global_idx + 1u] = smem[smem_idx + 1u];
output_batch[global_idx + 2u] = smem[smem_idx + 2u];
output_batch[global_idx + 3u] = smem[smem_idx + 3u];
output_batch[global_idx + 4u] = smem[smem_idx + 4u];
output_batch[global_idx + 5u] = smem[smem_idx + 5u];
output_batch[global_idx + 6u] = smem[smem_idx + 6u];
output_batch[global_idx + 7u] = smem[smem_idx + 7u];
}}
}}
",
            batch_size = self.batch_size
        )
    }
}
impl Kernel for Lz4WarpCompressKernel {
    /// Entry-point symbol shared by the PTX and WGSL emissions.
    fn name(&self) -> &str {
        "lz4_compress_warp"
    }

    /// Lowers the kernel to PTX through the builder DSL.
    ///
    /// Pipeline (one warp per page):
    ///   1. cooperative load of the page into this warp's shared slab
    ///      (out-of-bounds warps zero-fill so later phases read defined data),
    ///   2. per-lane OR + shared-memory reduction for zero-page detection,
    ///   3. the leader lane (lane 0) runs a byte-at-a-time LZ4 compressor for
    ///      non-zero pages and writes the reported size,
    ///   4. cooperative copy of the (uncompressed) page to the output buffer.
    ///
    /// Per-warp shared layout, WARP_SMEM_SIZE = PAGE_SIZE + 2 * LZ4_HASH_SIZE:
    ///   [0, PAGE_SIZE)                    page bytes
    ///   [PAGE_SIZE, PAGE_SIZE + 128)      32 x u32 reduction scratch
    ///   [PAGE_SIZE, PAGE_SIZE + 2*HASH)   u16 hash table — aliases the
    ///                                     scratch above and is never
    ///                                     initialized; stale entries are
    ///                                     filtered by the 0xFFFF marker and
    ///                                     the 4-byte match compare
    ///   [WARP_SMEM_SIZE, +12)             compressor state (in_pos, out_pos,
    ///                                     anchor)
    ///
    /// NOTE(review): the 12-byte state block starts at exactly
    /// WARP_SMEM_SIZE, i.e. inside the NEXT warp's slab — and past the end of
    /// the 4 * WARP_SMEM_SIZE allocation for the last warp. Confirm the
    /// intended layout / allocation size before relying on the compress path.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new(self.name())
            .param(PtxType::U64, "input_batch")
            .param(PtxType::U64, "output_batch")
            .param(PtxType::U64, "output_sizes")
            .param(PtxType::U32, "batch_size")
            .shared_memory(self.shared_memory_bytes())
            .build(|ctx| {
                // Thread geometry: warp_id = tid >> 5, lane_id = tid & 31.
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let shift_5 = ctx.mov_u32_imm(5);
                let warp_id = ctx.shr_u32(thread_id, shift_5);
                // page_id = block_id * 4 + warp_id, widened to 64 bits.
                let page_id = ctx.mul_wide_u32(block_id, 4);
                let warp_id_64 = ctx.cvt_u64_u32(warp_id);
                let page_id = ctx.add_u64(page_id, warp_id_64);
                let page_id_32 = ctx.cvt_u32_u64(page_id);
                let mask_31 = ctx.mov_u32_imm(31);
                let lane_id = ctx.and_u32(thread_id, mask_31);
                // Bounds predicate: page_id < batch_size.
                let batch_param = ctx.load_param_u32("batch_size");
                let in_bounds_pred = ctx.setp_lt_u32(page_id_32, batch_param);
                // Global base addresses for this page's input and output.
                let page_offset = ctx.mul_wide_u32(page_id_32, PAGE_SIZE);
                let input_ptr = ctx.load_param_u64("input_batch");
                let input_page_ptr = ctx.add_u64(input_ptr, page_offset);
                let output_ptr = ctx.load_param_u64("output_batch");
                let output_page_ptr = ctx.add_u64(output_ptr, page_offset);
                // Base address of this warp's slab in shared memory.
                const WARP_SMEM_SIZE: u32 = PAGE_SIZE + LZ4_HASH_SIZE * 2;
                let warp_smem_offset = ctx.mul_u32(warp_id, WARP_SMEM_SIZE);
                let warp_smem_offset_64 = ctx.cvt_u64_u32(warp_smem_offset);
                let raw_smem_base = ctx.shared_base_addr();
                let smem_base = ctx.add_u64(raw_smem_base, warp_smem_offset_64);
                // Each lane owns 128 contiguous bytes of the page.
                let load_base = ctx.mul_u32(lane_id, 128);
                // Reusable 64-bit byte offsets for the 8-word unrolled copies.
                let imm4 = ctx.mov_u64_imm(4);
                let imm8 = ctx.mov_u64_imm(8);
                let imm12 = ctx.mov_u64_imm(12);
                let imm16 = ctx.mov_u64_imm(16);
                let imm20 = ctx.mov_u64_imm(20);
                let imm24 = ctx.mov_u64_imm(24);
                let imm28 = ctx.mov_u64_imm(28);
                // Phase 1: cooperative load (in-bounds warps only).
                ctx.branch_if_not(in_bounds_pred, "L_skip_global_load");
                for i in 0..4u32 {
                    // 32 bytes (8 u32s) per lane per iteration.
                    let chunk_off = ctx.add_u32(load_base, i * 32);
                    let chunk_off_64 = ctx.cvt_u64_u32(chunk_off);
                    let load_addr = ctx.add_u64(input_page_ptr, chunk_off_64);
                    let d0 = ctx.ld_global_u32(load_addr);
                    let off4 = ctx.add_u64(load_addr, imm4);
                    let d1 = ctx.ld_global_u32(off4);
                    let off8 = ctx.add_u64(load_addr, imm8);
                    let d2 = ctx.ld_global_u32(off8);
                    let off12 = ctx.add_u64(load_addr, imm12);
                    let d3 = ctx.ld_global_u32(off12);
                    let off16 = ctx.add_u64(load_addr, imm16);
                    let d4 = ctx.ld_global_u32(off16);
                    let off20 = ctx.add_u64(load_addr, imm20);
                    let d5 = ctx.ld_global_u32(off20);
                    let off24 = ctx.add_u64(load_addr, imm24);
                    let d6 = ctx.ld_global_u32(off24);
                    let off28 = ctx.add_u64(load_addr, imm28);
                    let d7 = ctx.ld_global_u32(off28);
                    let smem_off = ctx.add_u64(smem_base, chunk_off_64);
                    ctx.st_generic_u32(smem_off, d0);
                    let smem_4 = ctx.add_u64(smem_off, imm4);
                    ctx.st_generic_u32(smem_4, d1);
                    let smem_8 = ctx.add_u64(smem_off, imm8);
                    ctx.st_generic_u32(smem_8, d2);
                    let smem_12 = ctx.add_u64(smem_off, imm12);
                    ctx.st_generic_u32(smem_12, d3);
                    let smem_16 = ctx.add_u64(smem_off, imm16);
                    ctx.st_generic_u32(smem_16, d4);
                    let smem_20 = ctx.add_u64(smem_off, imm20);
                    ctx.st_generic_u32(smem_20, d5);
                    let smem_24 = ctx.add_u64(smem_off, imm24);
                    ctx.st_generic_u32(smem_24, d6);
                    let smem_28 = ctx.add_u64(smem_off, imm28);
                    ctx.st_generic_u32(smem_28, d7);
                }
                ctx.branch("L_after_global_load");
                // Out-of-bounds warps zero-fill their slab so the barrier
                // below is reached by all threads with defined shared data.
                ctx.label("L_skip_global_load");
                let zero_val = ctx.mov_u32_imm(0);
                for i in 0..4u32 {
                    let chunk_off = ctx.add_u32(load_base, i * 32);
                    let chunk_off_64 = ctx.cvt_u64_u32(chunk_off);
                    let smem_off = ctx.add_u64(smem_base, chunk_off_64);
                    ctx.st_generic_u32(smem_off, zero_val);
                    let smem_4 = ctx.add_u64(smem_off, imm4);
                    ctx.st_generic_u32(smem_4, zero_val);
                    let smem_8 = ctx.add_u64(smem_off, imm8);
                    ctx.st_generic_u32(smem_8, zero_val);
                    let smem_12 = ctx.add_u64(smem_off, imm12);
                    ctx.st_generic_u32(smem_12, zero_val);
                    let smem_16 = ctx.add_u64(smem_off, imm16);
                    ctx.st_generic_u32(smem_16, zero_val);
                    let smem_20 = ctx.add_u64(smem_off, imm20);
                    ctx.st_generic_u32(smem_20, zero_val);
                    let smem_24 = ctx.add_u64(smem_off, imm24);
                    ctx.st_generic_u32(smem_24, zero_val);
                    let smem_28 = ctx.add_u64(smem_off, imm28);
                    ctx.st_generic_u32(smem_28, zero_val);
                }
                ctx.label("L_after_global_load");
                ctx.bar_sync(0);
                // Phase 2: per-lane OR over the lane's 128 bytes.
                // NOTE(review): each `let chunk_val = ctx.or_u32(..)` inside
                // the loop shadows the outer binding, and the final OR result
                // is discarded (`let _ =`). The value stored to the scratch
                // below is the OUTER `chunk_val` register (initialized to 0).
                // This is only correct if `or_u32` accumulates into its first
                // operand's register — confirm the builder's semantics.
                let chunk_val = ctx.mov_u32_imm(0);
                for i in 0..4u32 {
                    let chunk_off = ctx.add_u32(load_base, i * 32);
                    let chunk_off_64 = ctx.cvt_u64_u32(chunk_off);
                    let smem_off = ctx.add_u64(smem_base, chunk_off_64);
                    let d0 = ctx.ld_generic_u32(smem_off);
                    let chunk_val = ctx.or_u32(chunk_val, d0);
                    let off4 = ctx.add_u64(smem_off, imm4);
                    let d1 = ctx.ld_generic_u32(off4);
                    let chunk_val = ctx.or_u32(chunk_val, d1);
                    let off8 = ctx.add_u64(smem_off, imm8);
                    let d2 = ctx.ld_generic_u32(off8);
                    let chunk_val = ctx.or_u32(chunk_val, d2);
                    let off12 = ctx.add_u64(smem_off, imm12);
                    let d3 = ctx.ld_generic_u32(off12);
                    let chunk_val = ctx.or_u32(chunk_val, d3);
                    let off16 = ctx.add_u64(smem_off, imm16);
                    let d4 = ctx.ld_generic_u32(off16);
                    let chunk_val = ctx.or_u32(chunk_val, d4);
                    let off20 = ctx.add_u64(smem_off, imm20);
                    let d5 = ctx.ld_generic_u32(off20);
                    let chunk_val = ctx.or_u32(chunk_val, d5);
                    let off24 = ctx.add_u64(smem_off, imm24);
                    let d6 = ctx.ld_generic_u32(off24);
                    let chunk_val = ctx.or_u32(chunk_val, d6);
                    let off28 = ctx.add_u64(smem_off, imm28);
                    let d7 = ctx.ld_generic_u32(off28);
                    let _ = ctx.or_u32(chunk_val, d7);
                }
                // Publish this lane's OR into the warp's reduction scratch
                // (byte offsets PAGE_SIZE..PAGE_SIZE+128 within the slab;
                // note this aliases the start of the hash-table region).
                let lane_off_bytes = ctx.mul_u32(lane_id, 4);
                let reduction_off = ctx.add_u32(lane_off_bytes, PAGE_SIZE);
                let reduction_off_64 = ctx.cvt_u64_u32(reduction_off);
                let reduction_addr = ctx.add_u64(smem_base, reduction_off_64);
                ctx.st_generic_u32(reduction_addr, chunk_val);
                ctx.bar_sync(0);
                // Only the warp leader of an in-bounds page writes the size.
                let zero = ctx.mov_u32_imm(0);
                let is_leader = ctx.setp_eq_u32(lane_id, zero);
                let can_write_size = ctx.and_pred(is_leader, in_bounds_pred);
                ctx.branch_if_not(can_write_size, "L_not_leader");
                // Leader reduces the 32 scratch words.
                // NOTE(review): same discarded-OR pattern as above — the
                // compare below reads the register `page_or` was bound to at
                // the `mov` (0) unless `or_u32` accumulates in place.
                let page_or = ctx.mov_u32_imm(0);
                for lane in 0..32u32 {
                    let lane_off = ctx.mov_u32_imm(PAGE_SIZE + lane * 4);
                    let lane_off_64 = ctx.cvt_u64_u32(lane_off);
                    let lane_addr = ctx.add_u64(smem_base, lane_off_64);
                    let lane_val = ctx.ld_generic_u32(lane_addr);
                    let _ = ctx.or_u32(page_or, lane_val);
                }
                let is_zero_page = ctx.setp_eq_u32(page_or, zero);
                // Address of output_sizes[page_id].
                let size_ptr = ctx.load_param_u64("output_sizes");
                let size_off = ctx.mul_wide_u32(page_id_32, 4);
                let size_addr = ctx.add_u64(size_ptr, size_off);
                // Zero page => 20-byte minimal LZ4 frame; else PAGE_SIZE cap.
                let compressed_zero_size = ctx.mov_u32_imm(20);
                let uncompressed_size = ctx.mov_u32_imm(PAGE_SIZE);
                ctx.branch_if(is_zero_page, "L_write_zero_size");
                ctx.branch("L_compress_start");
                // Re-entry point after the compression loop finishes.
                ctx.label("L_compress_done");
                ctx.branch("L_write_compressed_size");
                // Phase 3: leader-lane LZ4 compression of a non-zero page.
                ctx.label("L_compress_start");
                let lz4_prime = ctx.mov_u32_imm(LZ4_HASH_MULT);
                let hash_shift = ctx.mov_u32_imm(32 - LZ4_HASH_BITS);
                let hash_mask = ctx.mov_u32_imm(LZ4_HASH_SIZE - 1);
                // Compressor state words live right after the hash table.
                // NOTE(review): this offset equals WARP_SMEM_SIZE — see the
                // layout note in the doc comment (overlaps the next slab).
                let state_off = ctx.mov_u32_imm(PAGE_SIZE + LZ4_HASH_SIZE * 2);
                let state_off_64 = ctx.cvt_u64_u32(state_off);
                let state_base = ctx.add_u64(smem_base, state_off_64);
                // state[0] = in_pos, state[4] = out_pos, state[8] = anchor.
                let zero_val = ctx.mov_u32_imm(0);
                ctx.st_generic_u32(state_base, zero_val);
                let imm4_state = ctx.mov_u64_imm(4);
                let out_pos_addr = ctx.add_u64(state_base, imm4_state);
                ctx.st_generic_u32(out_pos_addr, zero_val);
                let imm8_state = ctx.mov_u64_imm(8);
                let anchor_addr = ctx.add_u64(state_base, imm8_state);
                ctx.st_generic_u32(anchor_addr, zero_val);
                // Stop matching 12 bytes before the end (LZ4 tail rule).
                let limit = ctx.mov_u32_imm(PAGE_SIZE - 12);
                ctx.label("L_compress_loop");
                // Reload state each iteration (loop carried through shared).
                let in_pos = ctx.ld_generic_u32(state_base);
                let out_pos = ctx.ld_generic_u32(out_pos_addr);
                let anchor = ctx.ld_generic_u32(anchor_addr);
                let at_limit = ctx.setp_ge_u32(in_pos, limit);
                ctx.branch_if(at_limit, "L_emit_remaining");
                // Hash the 4 bytes at in_pos: (v * PRIME) >> (32 - BITS).
                let in_pos_64 = ctx.cvt_u64_u32(in_pos);
                let curr_addr = ctx.add_u64(smem_base, in_pos_64);
                let curr_val = ctx.ld_generic_u32(curr_addr);
                let hash_tmp = ctx.mul_lo_u32(curr_val, lz4_prime);
                let hash_shifted = ctx.shr_u32(hash_tmp, hash_shift);
                let hash_idx = ctx.and_u32(hash_shifted, hash_mask);
                let hash_table_off = ctx.mov_u32_imm(PAGE_SIZE);
                let hash_table_off_64 = ctx.cvt_u64_u32(hash_table_off);
                let hash_table_base = ctx.add_u64(smem_base, hash_table_off_64);
                let hash_entry_off = ctx.mul_u32(hash_idx, 2);
                let hash_entry_off_64 = ctx.cvt_u64_u32(hash_entry_off);
                let hash_entry_addr = ctx.add_u64(hash_table_base, hash_entry_off_64);
                // Candidate = previous position with the same hash (u16).
                let match_pos_u16 = ctx.ld_generic_u16(hash_entry_addr);
                let match_pos = ctx.cvt_u32_u16(match_pos_u16);
                ctx.st_generic_u16(hash_entry_addr, in_pos);
                // 0xFFFF marks "no entry"; stale garbage is also caught by
                // the offset-range checks and the 4-byte compare below.
                let invalid_marker = ctx.mov_u32_imm(0xFFFF);
                let no_match_candidate = ctx.setp_eq_u32(match_pos, invalid_marker);
                ctx.branch_if(no_match_candidate, "L_no_match");
                // Validate the candidate: 0 < offset <= LZ4_MAX_OFFSET.
                // (A wrapped negative offset shows up as a huge unsigned
                // value and fails the range check.)
                let offset = ctx.sub_u32_reg(in_pos, match_pos);
                let max_offset_plus_one = ctx.mov_u32_imm(LZ4_MAX_OFFSET + 1);
                let offset_too_large = ctx.setp_ge_u32(offset, max_offset_plus_one);
                ctx.branch_if(offset_too_large, "L_no_match");
                let zero_check = ctx.mov_u32_imm(0);
                let offset_is_zero = ctx.setp_eq_u32(offset, zero_check);
                ctx.branch_if(offset_is_zero, "L_no_match");
                // Confirm a real 4-byte match at the candidate position.
                let match_pos_64 = ctx.cvt_u64_u32(match_pos);
                let match_addr = ctx.add_u64(smem_base, match_pos_64);
                let match_val = ctx.ld_generic_u32(match_addr);
                ctx.label("L_check_match");
                let vals_equal = ctx.setp_eq_u32(curr_val, match_val);
                ctx.branch_if_not(vals_equal, "L_no_match");
                ctx.label("L_found_match");
                // Emit one LZ4 sequence: token, literals, 2-byte offset.
                ctx.label("L_encode_sequence");
                let literal_len = ctx.sub_u32_reg(in_pos, anchor);
                // Match length is fixed at the 4-byte minimum; no match
                // extension is attempted.
                let _match_len = ctx.mov_u32_imm(LZ4_MIN_MATCH);
                let fifteen = ctx.mov_u32_imm(15);
                let lit_ge_15 = ctx.setp_ge_u32(literal_len, fifteen);
                let token_lit = ctx.selp_u32(lit_ge_15, fifteen, literal_len);
                let four_bits = ctx.mov_u32_imm(4);
                let token_lit_shifted = ctx.shl_u32(token_lit, four_bits);
                // NOTE(review): the low nibble (match length - 4) is left at
                // 0, which encodes exactly a 4-byte match — consistent with
                // the fixed LZ4_MIN_MATCH above, but confirm the decoder
                // expects it.
                let token = token_lit_shifted;
                let out_pos_64 = ctx.cvt_u64_u32(out_pos);
                let out_addr = ctx.add_u64(output_page_ptr, out_pos_64);
                ctx.st_global_u8(out_addr, token);
                let _one_32 = ctx.mov_u32_imm(1);
                let out_pos_1 = ctx.add_u32(out_pos, 1);
                // Extension byte for literal runs >= 15 (single byte only;
                // runs >= 270 would need more 255-bytes per the LZ4 format).
                ctx.branch_if_not(lit_ge_15, "L_skip_ext_lit_len");
                let lit_minus_15 = ctx.sub_u32_reg(literal_len, fifteen);
                let out_pos_1_64 = ctx.cvt_u64_u32(out_pos_1);
                let ext_addr = ctx.add_u64(output_page_ptr, out_pos_1_64);
                ctx.st_global_u8(ext_addr, lit_minus_15);
                let out_pos_2 = ctx.add_u32(out_pos_1, 1);
                ctx.st_generic_u32(out_pos_addr, out_pos_2);
                ctx.branch("L_copy_literals");
                ctx.label("L_skip_ext_lit_len");
                ctx.st_generic_u32(out_pos_addr, out_pos_1);
                // Byte-wise copy of the literal run [anchor, in_pos).
                // The in_pos state word is reused as the loop counter here
                // and restored (to new_anchor) after the sequence is emitted.
                ctx.label("L_copy_literals");
                let out_pos_cur = ctx.ld_generic_u32(out_pos_addr);
                let lit_copy_idx = ctx.mov_u32_imm(0);
                ctx.st_generic_u32(state_base, lit_copy_idx);
                ctx.label("L_copy_lit_loop");
                let copy_idx = ctx.ld_generic_u32(state_base);
                let copy_done = ctx.setp_ge_u32(copy_idx, literal_len);
                ctx.branch_if(copy_done, "L_copy_lit_done");
                let src_off = ctx.add_u32_reg(anchor, copy_idx);
                let src_off_64 = ctx.cvt_u64_u32(src_off);
                let src_addr = ctx.add_u64(smem_base, src_off_64);
                let byte_val = ctx.ld_generic_u8(src_addr);
                let dst_off = ctx.add_u32_reg(out_pos_cur, copy_idx);
                let dst_off_64 = ctx.cvt_u64_u32(dst_off);
                let dst_addr = ctx.add_u64(output_page_ptr, dst_off_64);
                ctx.st_global_u8(dst_addr, byte_val);
                let copy_idx_next = ctx.add_u32(copy_idx, 1);
                ctx.st_generic_u32(state_base, copy_idx_next);
                ctx.branch("L_copy_lit_loop");
                ctx.label("L_copy_lit_done");
                // Little-endian 2-byte match offset after the literals.
                let out_pos_after_lit = ctx.add_u32_reg(out_pos_cur, literal_len);
                let out_pos_after_lit_64 = ctx.cvt_u64_u32(out_pos_after_lit);
                let offset_addr = ctx.add_u64(output_page_ptr, out_pos_after_lit_64);
                let mask_ff = ctx.mov_u32_imm(0xFF);
                let offset_lo = ctx.and_u32(offset, mask_ff);
                ctx.st_global_u8(offset_addr, offset_lo);
                let imm_1 = ctx.mov_u64_imm(1);
                let offset_addr_1 = ctx.add_u64(offset_addr, imm_1);
                let eight_bits = ctx.mov_u32_imm(8);
                let offset_hi = ctx.shr_u32(offset, eight_bits);
                ctx.st_global_u8(offset_addr_1, offset_hi);
                let out_pos_after_offset = ctx.add_u32(out_pos_after_lit, 2);
                // Advance past the (minimum-length) match and loop.
                let new_anchor = ctx.add_u32(in_pos, LZ4_MIN_MATCH);
                ctx.st_generic_u32(anchor_addr, new_anchor);
                ctx.st_generic_u32(state_base, new_anchor);
                ctx.st_generic_u32(out_pos_addr, out_pos_after_offset);
                ctx.branch("L_compress_loop");
                // No usable match: advance the scan position by one byte.
                ctx.label("L_no_match");
                let in_pos_next = ctx.add_u32(in_pos, 1);
                ctx.st_generic_u32(state_base, in_pos_next);
                ctx.branch("L_compress_loop");
                // Tail: emit the remaining [anchor, PAGE_SIZE) as literals.
                ctx.label("L_emit_remaining");
                let _final_in_pos = ctx.ld_generic_u32(state_base);
                let final_out_pos = ctx.ld_generic_u32(out_pos_addr);
                let final_anchor = ctx.ld_generic_u32(anchor_addr);
                let page_size_val = ctx.mov_u32_imm(PAGE_SIZE);
                let final_lit_len = ctx.sub_u32_reg(page_size_val, final_anchor);
                let no_remaining = ctx.setp_eq_u32(final_lit_len, zero_val);
                ctx.branch_if(no_remaining, "L_finalize_size");
                let final_lit_ge_15 = ctx.setp_ge_u32(final_lit_len, fifteen);
                let final_token_lit = ctx.selp_u32(final_lit_ge_15, fifteen, final_lit_len);
                let final_token = ctx.shl_u32(final_token_lit, four_bits);
                let final_out_64 = ctx.cvt_u64_u32(final_out_pos);
                let final_token_addr = ctx.add_u64(output_page_ptr, final_out_64);
                ctx.st_global_u8(final_token_addr, final_token);
                let final_out_1 = ctx.add_u32(final_out_pos, 1);
                ctx.branch_if_not(final_lit_ge_15, "L_final_copy_literals");
                let final_ext_len = ctx.sub_u32_reg(final_lit_len, fifteen);
                let final_ext_64 = ctx.cvt_u64_u32(final_out_1);
                let final_ext_addr = ctx.add_u64(output_page_ptr, final_ext_64);
                ctx.st_global_u8(final_ext_addr, final_ext_len);
                let final_out_2 = ctx.add_u32(final_out_1, 1);
                ctx.st_generic_u32(out_pos_addr, final_out_2);
                ctx.branch("L_final_do_copy");
                ctx.label("L_final_copy_literals");
                ctx.st_generic_u32(out_pos_addr, final_out_1);
                ctx.label("L_final_do_copy");
                let final_out_cur = ctx.ld_generic_u32(out_pos_addr);
                let final_copy_idx = ctx.mov_u32_imm(0);
                ctx.st_generic_u32(state_base, final_copy_idx);
                ctx.label("L_final_copy_loop");
                let fcopy_idx = ctx.ld_generic_u32(state_base);
                let fcopy_done = ctx.setp_ge_u32(fcopy_idx, final_lit_len);
                ctx.branch_if(fcopy_done, "L_finalize_size");
                let fsrc_off = ctx.add_u32_reg(final_anchor, fcopy_idx);
                let fsrc_off_64 = ctx.cvt_u64_u32(fsrc_off);
                let fsrc_addr = ctx.add_u64(smem_base, fsrc_off_64);
                let fbyte_val = ctx.ld_generic_u8(fsrc_addr);
                let fdst_off = ctx.add_u32_reg(final_out_cur, fcopy_idx);
                let fdst_off_64 = ctx.cvt_u64_u32(fdst_off);
                let fdst_addr = ctx.add_u64(output_page_ptr, fdst_off_64);
                ctx.st_global_u8(fdst_addr, fbyte_val);
                let fcopy_next = ctx.add_u32(fcopy_idx, 1);
                ctx.st_generic_u32(state_base, fcopy_next);
                ctx.branch("L_final_copy_loop");
                // out_pos still points just past the token/extension bytes;
                // add the literal count to get the total compressed size.
                ctx.label("L_finalize_size");
                let final_size = ctx.ld_generic_u32(out_pos_addr);
                let total_size = ctx.add_u32_reg(final_size, final_lit_len);
                ctx.st_generic_u32(out_pos_addr, total_size);
                // NOTE(review): label is never branched to; bar.sync 2 here
                // is reached only by leader lanes of non-zero in-bounds pages
                // — a barrier in divergent control flow. Confirm
                // analyze_barrier_safety accounts for this.
                ctx.label("L_lane_sync");
                ctx.bar_sync(2);
                ctx.branch("L_compress_done");
                // Report min(compressed, PAGE_SIZE): expansion falls back to
                // the uncompressed size.
                ctx.label("L_write_compressed_size");
                let compressed_size = ctx.ld_generic_u32(out_pos_addr);
                let expanded = ctx.setp_ge_u32(compressed_size, uncompressed_size);
                let final_reported_size =
                    ctx.selp_u32(expanded, uncompressed_size, compressed_size);
                ctx.st_global_u32(size_addr, final_reported_size);
                ctx.branch("L_after_size_write");
                ctx.label("L_write_zero_size");
                ctx.st_global_u32(size_addr, compressed_zero_size);
                ctx.label("L_after_size_write");
                ctx.label("L_not_leader");
                ctx.bar_sync(0);
                // Phase 4: cooperative copy-through store (in-bounds only).
                ctx.branch_if_not(in_bounds_pred, "L_exit");
                for i in 0..4u32 {
                    let chunk_off = ctx.add_u32(load_base, i * 32);
                    let chunk_off_64 = ctx.cvt_u64_u32(chunk_off);
                    let smem_off = ctx.add_u64(smem_base, chunk_off_64);
                    let d0 = ctx.ld_generic_u32(smem_off);
                    let ld_4 = ctx.add_u64(smem_off, imm4);
                    let d1 = ctx.ld_generic_u32(ld_4);
                    let ld_8 = ctx.add_u64(smem_off, imm8);
                    let d2 = ctx.ld_generic_u32(ld_8);
                    let ld_12 = ctx.add_u64(smem_off, imm12);
                    let d3 = ctx.ld_generic_u32(ld_12);
                    let ld_16 = ctx.add_u64(smem_off, imm16);
                    let d4 = ctx.ld_generic_u32(ld_16);
                    let ld_20 = ctx.add_u64(smem_off, imm20);
                    let d5 = ctx.ld_generic_u32(ld_20);
                    let ld_24 = ctx.add_u64(smem_off, imm24);
                    let d6 = ctx.ld_generic_u32(ld_24);
                    let ld_28 = ctx.add_u64(smem_off, imm28);
                    let d7 = ctx.ld_generic_u32(ld_28);
                    let store_addr = ctx.add_u64(output_page_ptr, chunk_off_64);
                    ctx.st_global_u32(store_addr, d0);
                    let st_4 = ctx.add_u64(store_addr, imm4);
                    ctx.st_global_u32(st_4, d1);
                    let st_8 = ctx.add_u64(store_addr, imm8);
                    ctx.st_global_u32(st_8, d2);
                    let st_12 = ctx.add_u64(store_addr, imm12);
                    ctx.st_global_u32(st_12, d3);
                    let st_16 = ctx.add_u64(store_addr, imm16);
                    ctx.st_global_u32(st_16, d4);
                    let st_20 = ctx.add_u64(store_addr, imm20);
                    ctx.st_global_u32(st_20, d5);
                    let st_24 = ctx.add_u64(store_addr, imm24);
                    ctx.st_global_u32(st_24, d6);
                    let st_28 = ctx.add_u64(store_addr, imm28);
                    ctx.st_global_u32(st_28, d7);
                }
                ctx.label("L_exit");
            })
    }
}
#[cfg(test)]
mod tests {
    // Covers kernel geometry (F051/F059), PTX emission structure
    // (F036-F045, F053-F061), WGSL emission structure (F046-F048,
    // F062-F064), determinism (F050/F055), and LZ4-specific structure
    // checks. The F0xx ids appear to map to an external feature/test plan
    // — confirm against project docs.
    use super::*;
    use crate::kernels::Kernel;

    // --- Configuration / geometry -------------------------------------
    #[test]
    fn test_f051_kernel_creation() {
        let kernel = Lz4WarpCompressKernel::new(1000);
        assert_eq!(kernel.batch_size(), 1000);
        assert_eq!(kernel.name(), "lz4_compress_warp");
    }
    #[test]
    fn test_f051_grid_dimensions() {
        // ceil(1000 / 4 pages-per-block) = 250 blocks.
        let kernel = Lz4WarpCompressKernel::new(1000);
        let (gx, gy, gz) = kernel.grid_dim();
        assert_eq!(gx, 250);
        assert_eq!(gy, 1);
        assert_eq!(gz, 1);
    }
    #[test]
    fn test_f051_block_dimensions() {
        let kernel = Lz4WarpCompressKernel::new(1000);
        let (bx, by, bz) = kernel.block_dim();
        assert_eq!(bx, 128);
        assert_eq!(by, 1);
        assert_eq!(bz, 1);
    }
    #[test]
    fn test_f052_shared_memory_size() {
        // Sanity bound only; 100 KiB is the sm_89 opt-in upper limit.
        let kernel = Lz4WarpCompressKernel::new(100);
        let smem = kernel.shared_memory_bytes();
        assert!(smem > 0);
        assert!(smem <= 100 * 1024);
    }

    // --- PTX emission structure ---------------------------------------
    #[test]
    fn test_f053_ptx_generation_valid() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".version"), "Missing PTX version");
        assert!(ptx.contains(".target"), "Missing PTX target");
        assert!(ptx.contains(".entry"), "Missing entry point");
    }
    #[test]
    fn test_f053_ptx_has_parameters() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("input_batch"));
        assert!(ptx.contains("output_batch"));
        assert!(ptx.contains("output_sizes"));
        assert!(ptx.contains("batch_size"));
    }
    #[test]
    fn test_f053_ptx_has_shared_memory() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains(".shared"));
    }
    #[test]
    fn test_f054_barrier_safety() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let result = kernel.analyze_barrier_safety();
        assert!(
            result.is_safe,
            "LZ4 kernel should be barrier-safe: {:?}",
            result.violations
        );
    }
    #[test]
    fn test_f055_kernel_name_deterministic() {
        let k1 = Lz4WarpCompressKernel::new(100);
        let k2 = Lz4WarpCompressKernel::new(100);
        assert_eq!(k1.name(), k2.name());
    }
    #[test]
    fn test_f056_ptx_has_barrier_sync() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("bar.sync"));
    }
    #[test]
    fn test_f059_grid_covers_all_pages() {
        // One warp per page: total warps across the grid must cover the
        // whole batch for a range of batch sizes.
        for batch_size in [1, 4, 5, 100, 1000, 18432] {
            let kernel = Lz4WarpCompressKernel::new(batch_size);
            let (gx, _, _) = kernel.grid_dim();
            let (bx, _, _) = kernel.block_dim();
            let warps_per_block = bx / 32;
            let total_warps = gx * warps_per_block;
            assert!(total_warps >= batch_size);
        }
    }
    #[test]
    fn test_f060_module_emission() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let module = kernel.as_module();
        let ptx = module.emit();
        assert!(ptx.contains(".version 8.0"));
        assert!(ptx.contains(".target sm_89"));
    }
    #[test]
    fn test_f061_ptx_validates_with_ptxas() {
        // Self-skipping integration check: runs the real ptxas assembler
        // when present. `which` is Unix-specific; on platforms without it
        // the Err branch triggers the skip.
        use std::io::Write;
        use std::process::Command;
        let ptxas_check = Command::new("which").arg("ptxas").output();
        if ptxas_check.is_err() || !ptxas_check.unwrap().status.success() {
            eprintln!("ptxas not available, skipping validation");
            return;
        }
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let mut tmpfile = std::env::temp_dir();
        tmpfile.push("lz4_compress_warp.ptx");
        let mut f = std::fs::File::create(&tmpfile).expect("Failed to create temp file");
        f.write_all(ptx.as_bytes()).expect("Failed to write PTX");
        let output = Command::new("ptxas")
            .args(["-arch=sm_89", tmpfile.to_str().unwrap(), "-o", "/dev/null"])
            .output()
            .expect("Failed to run ptxas");
        let _ = std::fs::remove_file(&tmpfile);
        assert!(
            output.status.success(),
            "ptxas validation failed:\nstdout: {}\nstderr: {}",
            String::from_utf8_lossy(&output.stdout),
            String::from_utf8_lossy(&output.stderr)
        );
    }

    // --- WGSL emission structure --------------------------------------
    #[test]
    fn test_f062_wgsl_generation_valid() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        assert!(wgsl.contains("@compute"), "Missing @compute attribute");
        assert!(wgsl.contains("@workgroup_size"), "Missing workgroup_size");
        assert!(
            wgsl.contains("workgroupBarrier"),
            "Missing workgroup barrier"
        );
    }
    #[test]
    fn test_f062_wgsl_has_bindings() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        assert!(
            wgsl.contains("@group(0) @binding(0)"),
            "Missing input binding"
        );
        assert!(
            wgsl.contains("@group(0) @binding(1)"),
            "Missing output binding"
        );
        assert!(
            wgsl.contains("@group(0) @binding(2)"),
            "Missing sizes binding"
        );
    }
    #[test]
    fn test_f062_wgsl_has_shared_memory() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        assert!(
            wgsl.contains("var<workgroup>"),
            "Missing workgroup shared memory"
        );
    }
    #[test]
    fn test_f063_wgsl_batch_size_embedded() {
        let kernel = Lz4WarpCompressKernel::new(500);
        let wgsl = kernel.emit_wgsl();
        assert!(
            wgsl.contains("500u"),
            "Batch size should be embedded in WGSL"
        );
    }
    #[test]
    fn test_f063_wgsl_has_entry_point() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        assert!(
            wgsl.contains("fn lz4_compress_warp"),
            "Missing entry point function"
        );
    }
    #[test]
    fn test_f064_wgsl_has_builtins() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        assert!(
            wgsl.contains("@builtin(workgroup_id)"),
            "Missing workgroup_id builtin"
        );
        assert!(
            wgsl.contains("@builtin(local_invocation_id)"),
            "Missing local_invocation_id builtin"
        );
    }
    #[test]
    fn test_f064_dual_backend_consistency() {
        // Both backends must expose the same entry name and a barrier.
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let wgsl = kernel.emit_wgsl();
        assert!(
            ptx.contains("bar.sync") || ptx.contains("barrier"),
            "PTX missing barrier"
        );
        assert!(wgsl.contains("workgroupBarrier"), "WGSL missing barrier");
        assert!(ptx.contains("lz4_compress_warp"));
        assert!(wgsl.contains("lz4_compress_warp"));
    }

    // --- Zero-page detection and reduction ----------------------------
    #[test]
    fn test_f036_ptx_has_zero_page_detection() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("or.b32"),
            "Missing OR operations for zero detection"
        );
        assert!(
            ptx.contains("L_write_zero_size"),
            "Missing zero-size output path"
        );
        assert!(
            ptx.contains("L_after_size_write"),
            "Missing size write merge label"
        );
    }
    #[test]
    fn test_f037_ptx_warp_reduction() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let bar_count = ptx.matches("bar.sync").count();
        assert!(
            bar_count >= 3,
            "Should have at least 3 barrier syncs, found {}",
            bar_count
        );
    }
    #[test]
    fn test_f038_zero_page_compressed_size() {
        // NOTE(review): a bare "20" substring also matches register
        // numbers or offsets anywhere in the PTX — this assertion is very
        // weak; consider matching the full mov of the size constant.
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("20"),
            "Should reference compressed zero page size"
        );
    }
    #[test]
    fn test_f039_page_id_calculation() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("%ctaid.x"), "Missing blockIdx.x access");
        assert!(ptx.contains("%tid.x"), "Missing threadIdx.x access");
    }
    #[test]
    fn test_f040_lane_id_masking() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("and.b32"), "Missing lane ID masking");
    }
    #[test]
    fn test_f041_shared_memory_allocation() {
        // Per warp: one page + 16-bit hash table; four warps per block.
        let kernel = Lz4WarpCompressKernel::new(100);
        let smem = kernel.shared_memory_bytes();
        let min_required = 4 * (PAGE_SIZE as usize + LZ4_HASH_SIZE as usize * 2);
        assert!(
            smem >= min_required,
            "Shared memory {} < required {}",
            smem,
            min_required
        );
    }
    #[test]
    fn test_f042_bounds_check_present() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("setp.lt"),
            "Missing bounds check comparison (setp.lt)"
        );
        assert!(ptx.contains("L_exit"), "Missing exit label for OOB pages");
    }
    #[test]
    fn test_f043_cooperative_load() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let ld_count = ptx.matches("ld.global.u32").count();
        assert!(
            ld_count >= 32,
            "Should have many global loads, found {}",
            ld_count
        );
    }
    #[test]
    fn test_f044_leader_thread_writes_size() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("setp.eq"), "Missing leader thread check");
        assert!(
            ptx.contains("L_not_leader"),
            "Missing non-leader skip label"
        );
    }
    #[test]
    fn test_f045_output_size_write() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(ptx.contains("st.global.u32"), "Missing size output store");
    }
    #[test]
    fn test_f046_wgsl_zero_page_detection() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        assert!(
            wgsl.contains("thread_or = thread_or |"),
            "Missing thread OR reduction"
        );
        assert!(
            wgsl.contains("if (page_or == 0u)"),
            "Missing zero page check"
        );
        assert!(wgsl.contains("20u"), "Missing compressed zero page size");
    }
    #[test]
    fn test_f047_wgsl_reduction_barrier() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let wgsl = kernel.emit_wgsl();
        let barrier_count = wgsl.matches("workgroupBarrier()").count();
        assert!(
            barrier_count >= 3,
            "Should have at least 3 barriers, found {}",
            barrier_count
        );
    }
    #[test]
    fn test_f048_shared_memory_reduction() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let wgsl = kernel.emit_wgsl();
        assert!(
            ptx.contains("st.u32"),
            "PTX missing generic store for reduction"
        );
        assert!(
            ptx.contains("ld.u32"),
            "PTX missing generic load for reduction"
        );
        assert!(
            ptx.contains(".shared"),
            "PTX missing shared memory declaration"
        );
        assert!(
            ptx.contains("cvta.shared"),
            "PTX missing cvta for shared->generic"
        );
        assert!(
            wgsl.contains("smem[reduction_idx]"),
            "WGSL missing shared memory reduction"
        );
    }
    #[test]
    fn test_f049_page_data_integrity() {
        // Full 4 KiB page round-trip needs >= 32 loads and stores per warp
        // (each of the 4 unrolled iterations moves 8 words).
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let global_loads = ptx.matches("ld.global.u32").count();
        let global_stores = ptx.matches("st.global.u32").count();
        assert!(global_loads >= 32, "Need at least 32 global loads for 4KB");
        assert!(
            global_stores >= 32,
            "Need at least 32 global stores for 4KB"
        );
    }
    #[test]
    fn test_f050_kernel_determinism() {
        // WGSL must be byte-identical across instances; PTX is compared
        // structurally (register numbering may differ, so count lines that
        // start with an alphabetic char as a proxy for instructions).
        let k1 = Lz4WarpCompressKernel::new(100);
        let k2 = Lz4WarpCompressKernel::new(100);
        let wgsl1 = k1.emit_wgsl();
        let wgsl2 = k2.emit_wgsl();
        assert_eq!(wgsl1, wgsl2, "WGSL should be deterministic");
        let ptx1 = k1.emit_ptx();
        let ptx2 = k2.emit_ptx();
        let instr_count_1 = ptx1
            .lines()
            .filter(|l| l.trim().starts_with(|c: char| c.is_alphabetic()))
            .count();
        let instr_count_2 = ptx2
            .lines()
            .filter(|l| l.trim().starts_with(|c: char| c.is_alphabetic()))
            .count();
        assert_eq!(
            instr_count_1, instr_count_2,
            "PTX instruction count should match"
        );
        assert_eq!(
            ptx1.matches("L_exit").count(),
            ptx2.matches("L_exit").count()
        );
        assert_eq!(
            ptx1.matches("L_not_leader").count(),
            ptx2.matches("L_not_leader").count()
        );
    }

    // --- LZ4 compression structure ------------------------------------
    #[test]
    fn test_gpu_lz4_ptx_has_hash_table() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("0x9e3779b1") || ptx.contains("2654435761") || ptx.contains("hash"),
            "PTX must have LZ4 hash computation (mul by 0x9E3779B1)"
        );
    }
    #[test]
    fn test_gpu_lz4_ptx_has_match_finding() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("match") || ptx.contains("L_found_match") || ptx.contains("L_check_match"),
            "PTX must have match finding logic with labeled branches"
        );
    }
    #[test]
    fn test_gpu_lz4_ptx_has_sequence_encoding() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        assert!(
            ptx.contains("token") || ptx.contains("L_encode") || ptx.contains("L_write_sequence"),
            "PTX must have LZ4 sequence encoding logic"
        );
    }
    #[test]
    fn test_gpu_lz4_ptx_has_output_buffer_management() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let has_dynamic_size =
            ptx.contains("out_pos") || ptx.contains("L_compress") || ptx.contains("compressed_len");
        assert!(
            has_dynamic_size,
            "PTX must track output buffer position dynamically for compression"
        );
    }
    #[test]
    fn test_gpu_lz4_kernel_has_compression_loop() {
        let kernel = Lz4WarpCompressKernel::new(100);
        let ptx = kernel.emit_ptx();
        let has_compress_loop = ptx.contains("L_compress_loop")
            || ptx.contains("L_main_loop")
            || (ptx.contains("bra") && ptx.contains("L_loop"));
        assert!(
            has_compress_loop,
            "GPU kernel must have main compression loop (L_compress_loop or similar)"
        );
    }
}