use hanzo_kernel::prelude::*;
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
fn dp4a_kernel(a: &Array<Vector<i32, Const<4>>>, b: &Array<Vector<i32, Const<4>>>, out: &mut Array<i32>) {
let i = ABSOLUTE_POS;
if i < out.len() {
out[i] = a[i].dot(b[i]);
}
}
fn main() {
use cubecl::wgpu::{WgpuDevice, WgpuRuntime};
let client = WgpuRuntime::client(&WgpuDevice::default());
let n = 8usize;
let a: Vec<i32> = (0..n * 4).map(|i| (i % 7) as i32 - 3).collect();
let b: Vec<i32> = (0..n * 4).map(|i| (i % 5) as i32 - 2).collect();
let ah = client.create_from_slice(i32::as_bytes(&a));
let bh = client.create_from_slice(i32::as_bytes(&b));
let oh = client.create_from_slice(i32::as_bytes(&vec![0i32; n]));
unsafe {
dp4a_kernel::launch_unchecked::<WgpuRuntime>(
&client, Grid::Static(1, 1, 1), Block::new_1d(n as u32),
ArrayArg::from_raw_parts(ah.clone(), n),
ArrayArg::from_raw_parts(bh.clone(), n),
ArrayArg::from_raw_parts(oh.clone(), n),
);
}
let bytes = client.read_one_unchecked(oh);
let out = i32::from_bytes(&bytes);
let refv: Vec<i32> = (0..n).map(|i| (0..4).map(|j| a[i*4+j]*b[i*4+j]).sum()).collect();
println!("dp4a out = {:?}", &out[..n]);
println!("ref = {:?}", refv);
println!("MATCH: {}", out[..n] == refv[..]);
use hanzo_kernel::quant::{matvec_q8_dp4a_ref, matvec_q8_dp4a_run, matvec_q8_bench, matvec_q8_ref};
let (rows, k) = (4096usize, 4096usize);
let nb = k / 32;
let mut seed = 0x9E3779B97F4A7C15u64;
let mut rnd = || { seed ^= seed << 13; seed ^= seed >> 7; seed ^= seed << 17; seed };
let wq: Vec<i32> = (0..rows * k).map(|_| (rnd() % 255) as i32 - 127).collect();
let xq: Vec<i32> = (0..k).map(|_| (rnd() % 255) as i32 - 127).collect();
let wd: Vec<f32> = (0..rows * nb).map(|_| (rnd() % 1000) as f32 / 8000.0 + 0.01).collect();
let reference = matvec_q8_dp4a_ref(&wq, &xq, &wd, rows, k);
let (got, ms_dp4a) = matvec_q8_dp4a_run::<WgpuRuntime>(&client, &wq, &xq, &wd, rows, k, 50);
let mut mr = 0f32;
for (a, b) in reference.iter().zip(got.iter()) { mr = mr.max((a - b).abs() / a.abs().max(1.0)); }
let xf: Vec<f32> = xq.iter().map(|&v| v as f32).collect();
let ms_scalar = matvec_q8_bench::<WgpuRuntime>(&client, &wd, &wq, &xf, rows, k, 50);
let _ = matvec_q8_ref(&wd, &wq, &xf, 1, k);
let gflops = |ms: f64| 2.0 * rows as f64 * k as f64 / (ms * 1e6);
let eff_gbps = |ms: f64| (rows * k) as f64 / (ms * 1e6);
println!("\n=== matvec {rows}x{k} (kernel-only, Vulkan gfx1151) ===");
println!("dp4a (Vector.dot/OpSDot): max_rel={:.2e} {} {:.3} ms {:.0} GFLOP/s {:.0} GB/s(int8-eff)",
mr, if mr < 2e-2 { "BIT-EXACT ✓" } else { "MISMATCH ✗" }, ms_dp4a, gflops(ms_dp4a), eff_gbps(ms_dp4a));
println!("scalar (i32 mul-add): {:.3} ms {:.0} GFLOP/s {:.0} GB/s(int8-eff)",
ms_scalar, gflops(ms_scalar), eff_gbps(ms_scalar));
println!("dp4a speedup over scalar: {:.2}x", ms_scalar / ms_dp4a);
}