#![cfg(feature = "vulkan")]
use hanzo_ml::quantized::{GgmlDType, QMatMul, QStorage, QTensor};
use hanzo_ml::{Device, Module, Tensor};
use std::sync::Arc;
/// Deterministic pseudo-random value in `[-1, 1)` derived from the index `i`.
///
/// The index is offset by the SplitMix64 increment and passed through the SplitMix64
/// finalizer. The top 24 bits of the mixed word are then mapped to the unit interval; 24 bits
/// fit exactly in an `f32` mantissa.
fn pseudo(i: usize) -> f32 {
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MUL_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_B: u64 = 0x94D0_49BB_1331_11EB;

    let state = (i as u64).wrapping_add(GAMMA);
    let state = (state ^ (state >> 30)).wrapping_mul(MUL_A);
    let state = (state ^ (state >> 27)).wrapping_mul(MUL_B);
    let mixed = state ^ (state >> 31);

    let unit = (mixed >> 40) as f32 / (1u32 << 24) as f32;
    unit * 2.0 - 1.0
}
/// Error statistics for one GPU quantized matvec, compared against CPU f64 references.
#[derive(Debug, Clone, Copy)]
struct ErrStats {
    /// Largest |gpu - ref| over the output rows, where `ref` uses the dequantized weights.
    max_abs: f32,
    /// NOTE: despite the name, this is not a per-element maximum. `run_case` stores
    /// `rms / RMS(ref)` here, i.e. the RMS error relative to the RMS magnitude of the
    /// dequantized-weight reference.
    max_rel: f32,
    /// Root-mean-square of (gpu - ref) over the output rows (dequantized-weight reference).
    rms: f32,
    /// Largest |gpu - ref| where `ref` uses the original unquantized f32 weights. It includes
    /// quantization error, and the tests only print it, never assert on it.
    quant_max_abs: f32,
}
/// Quantizes a deterministic `(nout, k)` weight matrix to `dtype` and uploads the raw blocks
/// to the Vulkan device. Runs the matching matvec kernel there on a deterministic length-`k`
/// input and returns error statistics against CPU f64 dot products.
///
/// `max_abs`, `rms` and `max_rel` use the dequantized weights as reference. `max_rel` is the
/// RMS error divided by the reference RMS. `quant_max_abs` uses the original f32 weights.
///
/// # Panics
/// Panics for any dtype other than Q4_0, Q8_0 or Q4K, or if the kernel returns a vector
/// whose length is not `nout`.
fn run_case(dev: &Device, dtype: GgmlDType, nout: usize, k: usize) -> hanzo_ml::Result<ErrStats> {
    let cpu = Device::Cpu;
    let weights: Vec<f32> = (0..nout * k).map(|i| pseudo(i) * 0.5).collect();
    // Offset the seed so the activation vector differs from the first weight row.
    let act: Vec<f32> = (0..k).map(|i| pseudo(i + 1_000_003)).collect();

    let w_tensor = Tensor::from_vec(weights.clone(), (nout, k), &cpu)?;
    let quantized = QTensor::quantize(&w_tensor, dtype)?;
    let dequant: Vec<f32> = quantized.dequantize(&cpu)?.flatten_all()?.to_vec1::<f32>()?;
    let blob = quantized.data()?;

    let vk = dev.as_vulkan_device()?;
    let gpu_weights = vk.upload_qweight(&blob)?;
    let gpu_out: Vec<f32> = match dtype {
        GgmlDType::Q4_0 => vk.matvec_q4_0(&gpu_weights, &act, nout, k)?,
        GgmlDType::Q8_0 => vk.matvec_q8_0(&gpu_weights, &act, nout, k)?,
        GgmlDType::Q4K => vk.matvec_q4k(&gpu_weights, &act, nout, k)?,
        _ => panic!("unsupported dtype in run_case: {dtype:?}"),
    };
    assert_eq!(gpu_out.len(), nout);

    // f64 dot product of one weight row with the activation vector, summed in index order.
    let dot = |row: &[f32]| -> f64 {
        row.iter()
            .zip(&act)
            .fold(0f64, |acc, (&w, &x)| acc + w as f64 * x as f64)
    };

    let (mut max_abs, mut quant_max_abs) = (0f32, 0f32);
    let (mut sse, mut ref_sq) = (0f64, 0f64);
    for (row, &out) in gpu_out.iter().enumerate() {
        let span = row * k..(row + 1) * k;
        let ref_deq = dot(&dequant[span.clone()]);
        let ref_orig = dot(&weights[span]);
        let out = out as f64;
        let diff = out - ref_deq;
        max_abs = max_abs.max(diff.abs() as f32);
        sse += diff * diff;
        ref_sq += ref_deq * ref_deq;
        quant_max_abs = quant_max_abs.max((out - ref_orig).abs() as f32);
    }

    let err_rms = (sse / nout as f64).sqrt();
    let ref_rms = (ref_sq / nout as f64).sqrt();
    Ok(ErrStats {
        max_abs,
        // Guard against a near-zero reference magnitude.
        max_rel: (err_rms / ref_rms.max(1e-9)) as f32,
        rms: err_rms as f32,
        quant_max_abs,
    })
}
/// Opens Vulkan device 0. If that fails, logs the error to stderr and returns `None`, so the
/// tests can return early instead of failing on machines without a Vulkan GPU.
fn gpu() -> Option<Device> {
    Device::new_vulkan(0)
        .map_err(|e| eprintln!("[vulkan_quant_tests] no Vulkan GPU ({e}); skipping"))
        .ok()
}
/// `(nout, k)` weight shapes exercised by each per-dtype matvec test.
/// Every `k` here is a multiple of 256, presumably to cover Q4_K's 256-element super-blocks.
/// Confirm before adding shapes.
const SHAPES: &[(usize, usize)] = &[
    (2048, 2048),
    (4096, 2048),
    (2048, 4096),
    (512, 256),
];
/// Q4_0 matvec on the GPU must match the CPU dequantized reference for every shape in
/// `SHAPES`. Returns early (passes) when no Vulkan device is available.
#[test]
fn vulkan_matvec_q4_0_matches_cpu() -> hanzo_ml::Result<()> {
    let dev = match gpu() {
        Some(dev) => dev,
        None => return Ok(()),
    };
    for (nout, k) in SHAPES.iter().copied() {
        let stats = run_case(&dev, GgmlDType::Q4_0, nout, k)?;
        println!(
            "Q4_0 nout={nout:5} k={k:5} max_abs={:.3e} max_rel={:.3e} rms={:.3e} (quant err vs f32: {:.3e})",
            stats.max_abs, stats.max_rel, stats.rms, stats.quant_max_abs
        );
        let within_tolerance = stats.max_rel < 1e-3 && stats.max_abs < 1e-3;
        assert!(
            within_tolerance,
            "Q4_0 GPU/CPU mismatch too large: max_abs={} max_rel={}",
            stats.max_abs,
            stats.max_rel
        );
    }
    Ok(())
}
/// Q8_0 matvec on the GPU must match the CPU dequantized reference for every shape in
/// `SHAPES`. Returns early (passes) when no Vulkan device is available.
#[test]
fn vulkan_matvec_q8_0_matches_cpu() -> hanzo_ml::Result<()> {
    let dev = match gpu() {
        Some(dev) => dev,
        None => return Ok(()),
    };
    for (nout, k) in SHAPES.iter().copied() {
        let stats = run_case(&dev, GgmlDType::Q8_0, nout, k)?;
        println!(
            "Q8_0 nout={nout:5} k={k:5} max_abs={:.3e} max_rel={:.3e} rms={:.3e} (quant err vs f32: {:.3e})",
            stats.max_abs, stats.max_rel, stats.rms, stats.quant_max_abs
        );
        let within_tolerance = stats.max_rel < 1e-3 && stats.max_abs < 1e-3;
        assert!(
            within_tolerance,
            "Q8_0 GPU/CPU mismatch too large: max_abs={} max_rel={}",
            stats.max_abs,
            stats.max_rel
        );
    }
    Ok(())
}
/// Q4_K matvec on the GPU must match the CPU dequantized reference for every shape in
/// `SHAPES`. Returns early (passes) when no Vulkan device is available.
#[test]
fn vulkan_matvec_q4k_matches_cpu() -> hanzo_ml::Result<()> {
    let dev = match gpu() {
        Some(dev) => dev,
        None => return Ok(()),
    };
    for (nout, k) in SHAPES.iter().copied() {
        let stats = run_case(&dev, GgmlDType::Q4K, nout, k)?;
        println!(
            "Q4_K nout={nout:5} k={k:5} max_abs={:.3e} max_rel={:.3e} rms={:.3e} (quant err vs f32: {:.3e})",
            stats.max_abs, stats.max_rel, stats.rms, stats.quant_max_abs
        );
        let within_tolerance = stats.max_rel < 1e-3 && stats.max_abs < 1e-3;
        assert!(
            within_tolerance,
            "Q4_K GPU/CPU mismatch too large: max_abs={} max_rel={}",
            stats.max_abs,
            stats.max_rel
        );
    }
    Ok(())
}
/// Exercises the public `QMatMul::forward` path on the Vulkan device.
///
/// Quantizes a deterministic `(nout, k)` matrix on the CPU and rebuilds a device-side
/// `QTensor` from its raw bytes. Multiplies a `(1, k)` input and returns the largest absolute
/// difference from an f64 reference computed with the dequantized weights.
///
/// # Panics
/// Panics if the forward output does not flatten to `nout` values.
fn end_to_end_case(dev: &Device, dtype: GgmlDType, nout: usize, k: usize) -> hanzo_ml::Result<f32> {
    let cpu = Device::Cpu;
    let weights: Vec<f32> = (0..nout * k).map(|i| pseudo(i) * 0.5).collect();
    let input: Vec<f32> = (0..k).map(|i| pseudo(i + 7)).collect();

    // CPU side: quantize, keep the raw block bytes for the upload, and dequantize for the
    // reference.
    let host_q = Arc::new(QTensor::quantize(
        &Tensor::from_vec(weights, (nout, k), &cpu)?,
        dtype,
    )?);
    let bytes = host_q.data()?.into_owned();
    let deq: Vec<f32> = host_q.dequantize(&cpu)?.flatten_all()?.to_vec1::<f32>()?;
    let expected: Vec<f64> = (0..nout)
        .map(|row| {
            deq[row * k..(row + 1) * k]
                .iter()
                .zip(&input)
                .fold(0f64, |acc, (&w, &x)| acc + w as f64 * x as f64)
        })
        .collect();

    // Device side: the same bytes wrapped as a QMatMul.
    let gpu_storage = QStorage::from_data(std::borrow::Cow::Owned(bytes), dev, dtype)?;
    let matmul = QMatMul::from_qtensor(QTensor::new(gpu_storage, (nout, k))?)?;
    let activations = Tensor::from_vec(input, (1, k), dev)?;
    let got = matmul.forward(&activations)?.flatten_all()?.to_vec1::<f32>()?;
    assert_eq!(got.len(), nout);

    Ok(got
        .iter()
        .zip(&expected)
        .fold(0f32, |worst, (&g, &r)| worst.max((g as f64 - r).abs() as f32)))
}
/// `QMatMul::forward` on the Vulkan device must match the dequantized CPU reference for
/// each dtype and shape below. Returns early (passes) when no Vulkan device is available.
#[test]
fn vulkan_qmatmul_forward_matches_cpu() -> hanzo_ml::Result<()> {
    let Some(dev) = gpu() else { return Ok(()) };
    let shapes: [(usize, usize); 3] = [(2048, 2048), (4096, 2048), (512, 256)];
    let dtypes = [GgmlDType::Q4_0, GgmlDType::Q8_0, GgmlDType::Q4K];
    for (nout, k) in shapes {
        for dt in dtypes {
            let max_abs = end_to_end_case(&dev, dt, nout, k)?;
            println!("QMatMul::forward {dt:?} nout={nout:5} k={k:5} GPU-vs-(dequant ref) max_abs={max_abs:.3e}");
            assert!(
                max_abs < 1e-3,
                "QMatMul::forward {dt:?} GPU/ref mismatch too large: {max_abs}"
            );
        }
    }
    Ok(())
}
/// Checks `QTensor::indexed_moe_forward` on the Vulkan device against an f64 reference.
///
/// Builds a deterministic expert bank of shape `(e_cnt, n, k)`, quantized row-wise as
/// `(e_cnt * n, k)`, and `t * topk` input vectors of length `k`. Slot `i` is routed to expert
/// `(7 * i + 3) % e_cnt`. Returns the largest absolute difference between each GPU output
/// element and the dot product of the routed expert's dequantized row with that slot's input.
fn moe_case(
    dev: &Device,
    dtype: GgmlDType,
    e_cnt: usize,
    n: usize,
    k: usize,
    t: usize,
    topk: usize,
) -> hanzo_ml::Result<f32> {
    let cpu = Device::Cpu;
    let slots = t * topk;

    let bank: Vec<f32> = (0..e_cnt * n * k).map(|i| pseudo(i) * 0.5).collect();
    let bank_t = Tensor::from_vec(bank, (e_cnt, n, k), &cpu)?;
    let q_bank = QTensor::quantize(&bank_t.reshape((e_cnt * n, k))?, dtype)?;
    let bank_deq: Vec<f32> = q_bank.dequantize(&cpu)?.flatten_all()?.to_vec1::<f32>()?;
    let bytes = q_bank.data()?.into_owned();

    let x_host: Vec<f32> = (0..slots * k).map(|i| pseudo(i + 11) * 0.7).collect();
    let ids_host: Vec<u32> = (0..slots).map(|i| ((i * 7 + 3) % e_cnt) as u32).collect();

    let gpu_storage = QStorage::from_data(std::borrow::Cow::Owned(bytes), dev, dtype)?;
    let q_vk = QTensor::new(gpu_storage, (e_cnt, n, k))?;
    let x_vk = Tensor::from_vec(x_host.clone(), (t, topk, k), dev)?;
    let ids_vk = Tensor::from_vec(ids_host.clone(), (t, topk), dev)?;
    let y_vk = q_vk
        .indexed_moe_forward(&x_vk, &ids_vk)?
        .reshape((slots, n))?
        .to_vec2::<f32>()?;

    let mut max_abs = 0f32;
    for (slot, (&expert, out_row)) in ids_host.iter().zip(&y_vk).enumerate() {
        let expert = expert as usize;
        let x_row = &x_host[slot * k..(slot + 1) * k];
        for (r, &got) in out_row.iter().enumerate() {
            let row_start = (expert * n + r) * k;
            let want = bank_deq[row_start..row_start + k]
                .iter()
                .zip(x_row)
                .fold(0f64, |acc, (&w, &x)| acc + w as f64 * x as f64);
            max_abs = max_abs.max((got as f64 - want).abs() as f32);
        }
    }
    Ok(max_abs)
}
/// Runs the Vulkan flash-attention kernel on deterministic inputs and returns the largest
/// absolute difference between its output and an eager f64 reference.
///
/// Buffers are flat and row-major, as indexed below. `q` and the output are `bh * lq * d`;
/// `k` and `v` are `bh * lk * d`. The reference computes `softmax(q·kᵀ * scale) · v` with
/// `scale = 1 / sqrt(d)`. When `causal` is set, the reference mask is bottom-right aligned:
/// query row `qi` sees keys `0..=qi + (lk - lq)`.
///
/// # Panics
/// Panics if `causal && lq > lk`. That mask would leave the leading query rows with no
/// visible key (a 0/0 softmax). `lk - lq` would also underflow: a context-free panic in
/// debug builds, and a silently wrong mask in release builds. Also panics if the kernel
/// output length is not `bh * lq * d`.
fn flash_case(
    dev: &Device,
    bh: usize,
    lq: usize,
    lk: usize,
    d: usize,
    causal: bool,
) -> hanzo_ml::Result<f32> {
    assert!(
        !causal || lq <= lk,
        "flash_case: causal reference requires lq <= lk (got lq={lq}, lk={lk})"
    );
    let scale = 1.0f32 / (d as f32).sqrt();
    let q: Vec<f32> = (0..bh * lq * d).map(|i| pseudo(i) * 0.5).collect();
    let k: Vec<f32> = (0..bh * lk * d).map(|i| pseudo(i + 5) * 0.5).collect();
    let v: Vec<f32> = (0..bh * lk * d).map(|i| pseudo(i + 9) * 0.5).collect();
    let vk = dev.as_vulkan_device()?;
    let out = vk.flash_attn(&q, &k, &v, bh, lq, lk, d, scale, causal)?;
    assert_eq!(out.len(), bh * lq * d);

    let scale = scale as f64;
    // One score/probability buffer reused across all query rows.
    let mut probs: Vec<f64> = Vec::with_capacity(lk);
    let mut max_abs = 0f32;
    for b in 0..bh {
        for qi in 0..lq {
            // Number of keys visible to this row. The assert above guarantees that `lk - lq`
            // cannot underflow when `causal` is set.
            let visible = if causal { (qi + (lk - lq) + 1).min(lk) } else { lk };
            let q_row = &q[(b * lq + qi) * d..(b * lq + qi + 1) * d];

            probs.clear();
            probs.extend((0..visible).map(|j| {
                let k_row = &k[(b * lk + j) * d..(b * lk + j + 1) * d];
                let dot = q_row
                    .iter()
                    .zip(k_row)
                    .fold(0f64, |acc, (&a, &c)| acc + a as f64 * c as f64);
                dot * scale
            }));

            // Numerically stable softmax: subtract the row max before exponentiating.
            let mx = probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            let mut denom = 0f64;
            for p in probs.iter_mut() {
                *p = (*p - mx).exp();
                denom += *p;
            }

            for t in 0..d {
                let weighted = probs
                    .iter()
                    .enumerate()
                    .fold(0f64, |acc, (j, &p)| acc + p * v[(b * lk + j) * d + t] as f64);
                let want = weighted / denom;
                let got = out[(b * lq + qi) * d + t] as f64;
                max_abs = max_abs.max((got - want).abs() as f32);
            }
        }
    }
    Ok(max_abs)
}
/// The Vulkan flash-attention kernel must match an eager f64 reference across non-causal,
/// causal square and single-query causal cases. Returns early (passes) without a Vulkan
/// device.
#[test]
fn vulkan_flash_attn_matches_cpu() -> hanzo_ml::Result<()> {
    let Some(dev) = gpu() else { return Ok(()) };
    // (bh, lq, lk, d, causal)
    let cases: [(usize, usize, usize, usize, bool); 4] = [
        (8, 16, 16, 128, false),
        (8, 16, 16, 128, true),
        (8, 1, 64, 128, true),
        (4, 7, 13, 64, false),
    ];
    for (bh, lq, lk, d, causal) in cases {
        let max_abs = flash_case(&dev, bh, lq, lk, d, causal)?;
        println!(
            "FlashAttn bh={bh} lq={lq:3} lk={lk:3} d={d:3} causal={causal} GPU-vs-(eager ref) max_abs={max_abs:.3e}"
        );
        assert!(max_abs < 1e-4, "FlashAttn GPU/ref mismatch too large: {max_abs}");
    }
    Ok(())
}
/// `indexed_moe_forward` on the Vulkan device must match the dequantized CPU reference for
/// each dtype and bank configuration below. Returns early (passes) without a Vulkan device.
#[test]
fn vulkan_moe_forward_matches_cpu() -> hanzo_ml::Result<()> {
    let Some(dev) = gpu() else { return Ok(()) };
    // (experts, n, k, tokens, topk)
    let cases: [(usize, usize, usize, usize, usize); 3] = [
        (8, 256, 256, 2, 2),
        (16, 512, 256, 3, 4),
        (4, 768, 512, 1, 2),
    ];
    let dtypes = [GgmlDType::Q4_0, GgmlDType::Q8_0, GgmlDType::Q4K];
    for (e_cnt, n, k, t, topk) in cases {
        for dt in dtypes {
            let max_abs = moe_case(&dev, dt, e_cnt, n, k, t, topk)?;
            println!(
                "MoE {dt:?} E={e_cnt:3} n={n:4} k={k:4} t={t} topk={topk} GPU-vs-(dequant ref) max_abs={max_abs:.3e}"
            );
            assert!(max_abs < 1e-3, "MoE {dt:?} GPU/ref mismatch too large: {max_abs}");
        }
    }
    Ok(())
}