#![allow(clippy::similar_names)]
use super::Lz4WarpCompressKernel;
impl Lz4WarpCompressKernel {
    /// Generates the WGSL source for the warp-cooperative LZ4 page kernel.
    ///
    /// Each 128-invocation workgroup is split into four logical 32-lane "warps". Each warp
    /// owns one 4096-byte page and processes it in three phases:
    /// 1. It stages the page from `input_batch` into workgroup memory.
    /// 2. It OR-reduces the page to detect an all-zero page. It then writes
    ///    `output_sizes[page]`: `20` for a zero page, `PAGE_SIZE` otherwise.
    /// 3. It copies the staged page verbatim to `output_batch`.
    ///
    /// No LZ4 encoding of page contents happens in this kernel.
    ///
    /// `self.batch_size` is baked into the shader as the number of valid pages. Warps
    /// whose page index is out of range skip every global-memory access, but they still
    /// execute every `workgroupBarrier()`. WGSL requires barriers to be reached in
    /// uniform control flow, so out-of-range warps must not return early.
    ///
    /// # Note
    /// The shader declares a 48KB `var<workgroup>` array. This exceeds WebGPU's default
    /// `maxComputeWorkgroupStorageSize` (16384 bytes), so the device must be created with
    /// a raised limit for pipeline creation to succeed.
    #[must_use]
    pub fn emit_wgsl(&self) -> String {
        format!(
            r"// LZ4 Warp-Cooperative Compression Kernel (WGSL)
// Generated by trueno-gpu - Pure Rust GPU code generation
// WebGPU cross-platform: Vulkan, Metal, DX12, WebGPU

// Constants
const PAGE_SIZE: u32 = 4096u;
const SUBGROUP_SIZE: u32 = 32u;
const PAGES_PER_WORKGROUP: u32 = 4u;

// Bindings
@group(0) @binding(0) var<storage, read> input_batch: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_batch: array<u32>;
@group(0) @binding(2) var<storage, read_write> output_sizes: array<u32>;

// Workgroup shared memory (48KB per workgroup)
// Layout (u32 words):
//   [0, 4096)    staged pages, 1024 words per warp
//   [4096, 4224) per-warp OR-reduction scratch, 32 words per warp
//   remainder    unused
var<workgroup> smem: array<u32, 12288>; // 48KB / 4 bytes

@compute @workgroup_size(128, 1, 1)
fn lz4_compress_warp(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {{
    let batch_size: u32 = {batch_size}u;

    // Logical warp/lane IDs: fixed 32-lane grouping of the flat invocation index
    // (no subgroup builtins are used; all synchronization is workgroup-wide)
    let thread_id = local_id.x;
    let warp_id = thread_id / SUBGROUP_SIZE;
    let lane_id = thread_id % SUBGROUP_SIZE;

    // Calculate page assignment
    let page_id = workgroup_id.x * PAGES_PER_WORKGROUP + warp_id;

    // Bounds check. Out-of-range warps must keep running (no early exit) so that
    // every invocation reaches each workgroupBarrier() in uniform control flow.
    let page_valid = page_id < batch_size;

    // Calculate memory offsets
    let page_words = PAGE_SIZE / 4u;
    let page_offset = page_id * page_words; // In u32 units
    let load_base = lane_id * 32u; // 128 bytes per thread = 32 u32s
    let smem_warp_base = warp_id * page_words;
    // Reduction scratch lives after ALL staged pages so it never aliases another warp's page
    let reduction_base = PAGES_PER_WORKGROUP * page_words + warp_id * SUBGROUP_SIZE;

    // Phase 1: Cooperative load from global to shared memory
    if (page_valid) {{
        for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
            let chunk_off = load_base + i * 8u; // 8 u32s per iteration
            let global_idx = page_offset + chunk_off;
            let smem_idx = smem_warp_base + chunk_off;
            // Load 8 u32s (32 bytes)
            smem[smem_idx + 0u] = input_batch[global_idx + 0u];
            smem[smem_idx + 1u] = input_batch[global_idx + 1u];
            smem[smem_idx + 2u] = input_batch[global_idx + 2u];
            smem[smem_idx + 3u] = input_batch[global_idx + 3u];
            smem[smem_idx + 4u] = input_batch[global_idx + 4u];
            smem[smem_idx + 5u] = input_batch[global_idx + 5u];
            smem[smem_idx + 6u] = input_batch[global_idx + 6u];
            smem[smem_idx + 7u] = input_batch[global_idx + 7u];
        }}
    }}

    // Workgroup barrier
    workgroupBarrier();

    // Phase 2: Zero-page detection with parallel reduction
    // Each thread checks if its 128 bytes are all zeros
    var thread_or: u32 = 0u;
    for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
        let chunk_off = load_base + i * 8u;
        let smem_idx = smem_warp_base + chunk_off;
        thread_or = thread_or | smem[smem_idx + 0u];
        thread_or = thread_or | smem[smem_idx + 1u];
        thread_or = thread_or | smem[smem_idx + 2u];
        thread_or = thread_or | smem[smem_idx + 3u];
        thread_or = thread_or | smem[smem_idx + 4u];
        thread_or = thread_or | smem[smem_idx + 5u];
        thread_or = thread_or | smem[smem_idx + 6u];
        thread_or = thread_or | smem[smem_idx + 7u];
    }}

    // Store each thread's result for reduction (dedicated scratch region)
    smem[reduction_base + lane_id] = thread_or;
    workgroupBarrier();

    // Lane 0 of each valid warp reduces its 32 values and writes the output size
    if (page_valid && lane_id == 0u) {{
        var page_or: u32 = 0u;
        for (var j: u32 = 0u; j < SUBGROUP_SIZE; j = j + 1u) {{
            page_or = page_or | smem[reduction_base + j];
        }}
        // Zero page: compressed to 20 bytes, non-zero: uncompressed
        if (page_or == 0u) {{
            output_sizes[page_id] = 20u; // LZ4 minimal encoding
        }} else {{
            output_sizes[page_id] = PAGE_SIZE;
        }}
    }}

    workgroupBarrier();

    // Phase 3: Cooperative store from shared to global memory
    if (page_valid) {{
        for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
            let chunk_off = load_base + i * 8u;
            let global_idx = page_offset + chunk_off;
            let smem_idx = smem_warp_base + chunk_off;
            output_batch[global_idx + 0u] = smem[smem_idx + 0u];
            output_batch[global_idx + 1u] = smem[smem_idx + 1u];
            output_batch[global_idx + 2u] = smem[smem_idx + 2u];
            output_batch[global_idx + 3u] = smem[smem_idx + 3u];
            output_batch[global_idx + 4u] = smem[smem_idx + 4u];
            output_batch[global_idx + 5u] = smem[smem_idx + 5u];
            output_batch[global_idx + 6u] = smem[smem_idx + 6u];
            output_batch[global_idx + 7u] = smem[smem_idx + 7u];
        }}
    }}
}}
",
            batch_size = self.batch_size
        )
    }
}