//! trueno-gpu 0.4.29
//!
//! Pure Rust PTX generation for NVIDIA CUDA - no LLVM, no nvcc.
// Allow similar names (warp_id / lane_id / page_id and smem_idx / global_idx are intentional)
#![allow(clippy::similar_names)]

use super::Lz4WarpCompressKernel;

impl Lz4WarpCompressKernel {
    /// Emit WGSL shader code for the WebGPU backend.
    ///
    /// Generates the WGSL equivalent of the CUDA kernel for cross-platform
    /// GPU compute: WGSL uses workgroups instead of CUDA blocks and
    /// subgroups instead of warps.
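    ///
    /// # Example
    ///
    /// A minimal sketch; `Lz4WarpCompressKernel::new` is a hypothetical
    /// constructor and may differ from the crate's actual API:
    ///
    /// ```ignore
    /// let kernel = Lz4WarpCompressKernel::new(64); // hypothetical constructor
    /// let wgsl = kernel.emit_wgsl();
    /// assert!(wgsl.contains("@compute @workgroup_size(128, 1, 1)"));
    /// // One warp compresses one page: dispatch ceil(64 / 4) = 16 workgroups.
    /// ```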
    #[must_use]
    pub fn emit_wgsl(&self) -> String {
        format!(
            r"// LZ4 Warp-Cooperative Compression Kernel (WGSL)
// Generated by trueno-gpu - Pure Rust GPU code generation
// WebGPU cross-platform: Vulkan, Metal, DX12, WebGPU

// Constants
const PAGE_SIZE: u32 = 4096u;
// Logical warp width. The kernel partitions the 128-thread workgroup into
// groups of 32; no subgroup builtins are used, so correctness does not
// depend on the hardware subgroup size.
const SUBGROUP_SIZE: u32 = 32u;
const PAGES_PER_WORKGROUP: u32 = 4u;

// Bindings
@group(0) @binding(0) var<storage, read> input_batch: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_batch: array<u32>;
@group(0) @binding(2) var<storage, read_write> output_sizes: array<u32>;

// Workgroup shared memory: 48KB, mirroring CUDA's shared-memory carve-out.
// Note: this exceeds WebGPU's baseline maxComputeWorkgroupStorageSize of
// 16384 bytes, so the device must be requested with a raised limit.
var<workgroup> smem: array<u32, 12288>;  // 48KB / 4 bytes per u32
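// Layout of smem, in u32 units:
//   [   0, 4096)  page data: 1024 u32s (4KB) per warp, 4 warps
//   [4096, 4224)  zero-detect scratch: 32 u32s per warp
//   [4224, 12288) unused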

@compute @workgroup_size(128, 1, 1)
fn lz4_compress_warp(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {{
    let batch_size: u32 = {batch_size}u;

    // Derive logical warp and lane IDs from the flat thread index
    let thread_id = local_id.x;
    let warp_id = thread_id / SUBGROUP_SIZE;
    let lane_id = thread_id % SUBGROUP_SIZE;

    // Calculate page assignment
    let page_id = workgroup_id.x * PAGES_PER_WORKGROUP + warp_id;

    // Bounds check
    if (page_id >= batch_size) {{
        return;
    }}

    // Calculate memory offsets
    let page_offset = page_id * (PAGE_SIZE / 4u);  // In u32 units
    let load_base = lane_id * 32u;  // 128 bytes per thread = 32 u32s
    let smem_warp_base = warp_id * (PAGE_SIZE / 4u);
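    // Worked example (workgroup 0, warp 2, lane 5): page_id = 2,
    // page_offset = 2048, load_base = 160, smem_warp_base = 2048; this lane
    // copies batch u32s [2208, 2240) into smem[2208..2240) below.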

    // Phase 1: Cooperative load from global to shared memory
    for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
        let chunk_off = load_base + i * 8u;  // 8 u32s per iteration
        let global_idx = page_offset + chunk_off;
        let smem_idx = smem_warp_base + chunk_off;

        // Load 8 u32s (32 bytes)
        smem[smem_idx + 0u] = input_batch[global_idx + 0u];
        smem[smem_idx + 1u] = input_batch[global_idx + 1u];
        smem[smem_idx + 2u] = input_batch[global_idx + 2u];
        smem[smem_idx + 3u] = input_batch[global_idx + 3u];
        smem[smem_idx + 4u] = input_batch[global_idx + 4u];
        smem[smem_idx + 5u] = input_batch[global_idx + 5u];
        smem[smem_idx + 6u] = input_batch[global_idx + 6u];
        smem[smem_idx + 7u] = input_batch[global_idx + 7u];
    }}

    // Workgroup barrier
    workgroupBarrier();

    // Phase 2: Zero-page detection with parallel reduction.
    // A page is all zeros iff the OR of all its 1024 words is zero; each
    // thread first ORs together its own 128 bytes (32 u32s).
    var thread_or: u32 = 0u;
    for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
        let chunk_off = load_base + i * 8u;
        let smem_idx = smem_warp_base + chunk_off;

        thread_or = thread_or | smem[smem_idx + 0u];
        thread_or = thread_or | smem[smem_idx + 1u];
        thread_or = thread_or | smem[smem_idx + 2u];
        thread_or = thread_or | smem[smem_idx + 3u];
        thread_or = thread_or | smem[smem_idx + 4u];
        thread_or = thread_or | smem[smem_idx + 5u];
        thread_or = thread_or | smem[smem_idx + 6u];
        thread_or = thread_or | smem[smem_idx + 7u];
    }}

    // Store each thread's result for the reduction. Scratch lives after ALL
    // page data (4 * 1024 u32s) so it cannot alias another warp's page region.
    let reduction_base = PAGES_PER_WORKGROUP * (PAGE_SIZE / 4u) + warp_id * SUBGROUP_SIZE;
    smem[reduction_base + lane_id] = thread_or;
    workgroupBarrier();

    // Lane 0 reduces all 32 values serially (a portable stand-in for
    // subgroupOr, which would need the WGSL subgroups extension) and
    // writes the output size.
    if (lane_id == 0u) {{
        var page_or: u32 = 0u;
        for (var j: u32 = 0u; j < SUBGROUP_SIZE; j = j + 1u) {{
            page_or = page_or | smem[reduction_base + j];
        }}

        // Zero page: compressed to 20 bytes, non-zero: uncompressed
        if (page_or == 0u) {{
            output_sizes[page_id] = 20u;  // LZ4 minimal encoding
        }} else {{
            output_sizes[page_id] = PAGE_SIZE;
        }}
    }}

    workgroupBarrier();

    // Phase 3: Cooperative store from shared to global memory. Pages are
    // copied through uncompressed; only output_sizes records the zero-page
    // result.
    for (var i: u32 = 0u; i < 4u; i = i + 1u) {{
        let chunk_off = load_base + i * 8u;
        let global_idx = page_offset + chunk_off;
        let smem_idx = smem_warp_base + chunk_off;

        output_batch[global_idx + 0u] = smem[smem_idx + 0u];
        output_batch[global_idx + 1u] = smem[smem_idx + 1u];
        output_batch[global_idx + 2u] = smem[smem_idx + 2u];
        output_batch[global_idx + 3u] = smem[smem_idx + 3u];
        output_batch[global_idx + 4u] = smem[smem_idx + 4u];
        output_batch[global_idx + 5u] = smem[smem_idx + 5u];
        output_batch[global_idx + 6u] = smem[smem_idx + 6u];
        output_batch[global_idx + 7u] = smem[smem_idx + 7u];
    }}
}}
",
            batch_size = self.batch_size
        )
    }
}
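
// A sanity-check sketch for the emitted WGSL. `Lz4WarpCompressKernel::new`
// is an assumed constructor here (the real API may differ); adapt before
// enabling.
#[cfg(test)]
mod wgsl_emit_tests {
    use super::*;

    #[test]
    fn emitted_wgsl_interpolates_batch_size() {
        // Hypothetical constructor; replace with the crate's actual one.
        let kernel = Lz4WarpCompressKernel::new(64);
        let wgsl = kernel.emit_wgsl();

        // The template must substitute `batch_size` and leave the escaped
        // WGSL braces intact.
        assert!(wgsl.contains("let batch_size: u32 = 64u;"));
        assert!(wgsl.contains("@compute @workgroup_size(128, 1, 1)"));
        assert!(!wgsl.contains("{batch_size}"));
    }
}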