hanzo-kernel 0.2.5

Hanzo's first-party GPU kernel DSL: one Rust source, lowered to CUDA/ROCm/Vulkan/Metal.
// dp4a lowering probe: does Vector<i32,4>.dot emit a hardware integer dot on Vulkan?
use hanzo_kernel::prelude::*;

#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
fn dp4a_kernel(a: &Array<Vector<i32, Const<4>>>, b: &Array<Vector<i32, Const<4>>>, out: &mut Array<i32>) {
    let i = ABSOLUTE_POS;
    if i < out.len() {
        out[i] = a[i].dot(b[i]);
    }
}

fn main() {
    use cubecl::wgpu::{WgpuDevice, WgpuRuntime};
    let client = WgpuRuntime::client(&WgpuDevice::default());
    let n = 8usize;
    let a: Vec<i32> = (0..n * 4).map(|i| (i % 7) as i32 - 3).collect();
    let b: Vec<i32> = (0..n * 4).map(|i| (i % 5) as i32 - 2).collect();
    let ah = client.create_from_slice(i32::as_bytes(&a));
    let bh = client.create_from_slice(i32::as_bytes(&b));
    let oh = client.create_from_slice(i32::as_bytes(&vec![0i32; n]));
    unsafe {
        dp4a_kernel::launch_unchecked::<WgpuRuntime>(
            &client, Grid::Static(1, 1, 1), Block::new_1d(n as u32),
            ArrayArg::from_raw_parts(ah.clone(), n),
            ArrayArg::from_raw_parts(bh.clone(), n),
            ArrayArg::from_raw_parts(oh.clone(), n),
        );
    }
    let bytes = client.read_one_unchecked(oh);
    let out = i32::from_bytes(&bytes);
    // reference
    let refv: Vec<i32> = (0..n).map(|i| (0..4).map(|j| a[i*4+j]*b[i*4+j]).sum()).collect();
    println!("dp4a out = {:?}", &out[..n]);
    println!("ref      = {:?}", refv);
    println!("MATCH: {}", out[..n] == refv[..]);

    // ---- matvec dp4a vs scalar on the same data (kernel-only throughput) ----
    use hanzo_kernel::quant::{matvec_q8_dp4a_ref, matvec_q8_dp4a_run, matvec_q8_bench, matvec_q8_ref};
    let (rows, k) = (4096usize, 4096usize);
    let nb = k / 32;
    let mut seed = 0x9E3779B97F4A7C15u64;
    let mut rnd = || { seed ^= seed << 13; seed ^= seed >> 7; seed ^= seed << 17; seed };
    let wq: Vec<i32> = (0..rows * k).map(|_| (rnd() % 255) as i32 - 127).collect();
    let xq: Vec<i32> = (0..k).map(|_| (rnd() % 255) as i32 - 127).collect();
    let wd: Vec<f32> = (0..rows * nb).map(|_| (rnd() % 1000) as f32 / 8000.0 + 0.01).collect();
    let reference = matvec_q8_dp4a_ref(&wq, &xq, &wd, rows, k);
    let (got, ms_dp4a) = matvec_q8_dp4a_run::<WgpuRuntime>(&client, &wq, &xq, &wd, rows, k, 50);
    let mut mr = 0f32;
    for (a, b) in reference.iter().zip(got.iter()) { mr = mr.max((a - b).abs() / a.abs().max(1.0)); }
    // scalar baseline on the SAME data (xq as f32)
    let xf: Vec<f32> = xq.iter().map(|&v| v as f32).collect();
    let ms_scalar = matvec_q8_bench::<WgpuRuntime>(&client, &wd, &wq, &xf, rows, k, 50);
    let _ = matvec_q8_ref(&wd, &wq, &xf, 1, k);
    let gflops = |ms: f64| 2.0 * rows as f64 * k as f64 / (ms * 1e6);
    // effective int8 weight BW (as if packed 1 byte/weight -- the hand-tuned accounting)
    let eff_gbps = |ms: f64| (rows * k) as f64 / (ms * 1e6);
    println!("\n=== matvec {rows}x{k} (kernel-only, Vulkan gfx1151) ===");
    println!("dp4a (Vector.dot/OpSDot): max_rel={:.2e} {}  {:.3} ms  {:.0} GFLOP/s  {:.0} GB/s(int8-eff)",
        mr, if mr < 2e-2 { "BIT-EXACT ✓" } else { "MISMATCH ✗" }, ms_dp4a, gflops(ms_dp4a), eff_gbps(ms_dp4a));
    println!("scalar (i32 mul-add):                        {:.3} ms  {:.0} GFLOP/s  {:.0} GB/s(int8-eff)",
        ms_scalar, gflops(ms_scalar), eff_gbps(ms_scalar));
    println!("dp4a speedup over scalar: {:.2}x", ms_scalar / ms_dp4a);
}

// appended: matvec dp4a vs scalar benchmark (called from a second entry below is not possible in one
// main; instead we inline into main via a module fn the build will include through the same file).