use crate::prelude::*;
/// Q8_0 block size: the number of quantized weights that share one `f32` scale.
///
/// The Q8 kernels and reference functions in this file hard-code the same value
/// (`k / 32`, `i / 32`, `g / 8` over 4-lane groups) rather than reading this constant,
/// so changing it here alone does not change their layout.
pub const QK8_0: usize = 32;
/// Dequantizing matrix-vector product for Q8_0-style weights (blocks of 32, cf. `QK8_0`):
/// `out[row] = Σ_i wd[row * (k / 32) + i / 32] * wq[row * k + i] * x[i]`.
///
/// One invocation computes one output row, indexed by `ABSOLUTE_POS`; invocations
/// with `row >= out.len()` do nothing. This lets the host round the grid up.
///
/// * `wd`  — per-block scales, `k / 32` per row, row-major.
/// * `wq`  — quantized weights stored one per `i32`, `k` per row, row-major.
/// * `x`   — input vector; elements `0..k` are read.
/// * `out` — one result per row.
/// * `k`   — row length, fixed at compile time. Assumed to be a multiple of 32
///   (`k / 32` truncates); that is not checked here.
///
/// NOTE(review): the kernel is declared `unchecked`, which presumably disables
/// device-side bounds checks, so the host must size every array correctly.
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
pub fn matvec_q8<F: Float>(
    wd: &Array<F>,
    wq: &Array<i32>,
    x: &Array<F>,
    out: &mut Array<F>,
    #[comptime] k: usize,
) {
    let row = ABSOLUTE_POS;
    // Guard the tail invocations of a grid that was rounded up to a whole number of blocks.
    if row < out.len() {
        let nb = k / 32;
        let wbase = row * k;
        let dbase = row * nb;
        let mut acc = F::new(0.0);
        for i in 0..k {
            // Every run of 32 consecutive weights shares one scale.
            let d = wd[dbase + i / 32];
            let q = F::cast_from(wq[wbase + i]);
            acc += d * q * x[i];
        }
        out[row] = acc;
    }
}
/// Uploads a Q8 matrix and input vector, runs [`matvec_q8`] once, and returns the
/// `rows` results.
///
/// Layout (same as the kernel and [`matvec_q8_ref`]):
/// * `wd` — `rows * (k / 32)` block scales, row-major;
/// * `wq` — `rows * k` quantized weights, one per `i32`, row-major;
/// * `x`  — `k` activations.
///
/// Fresh device buffers are created on every call. When `rows` is 0 the function
/// returns an empty vector without touching the device.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 32, if any slice is shorter than the layout
/// above requires, or if `rows` does not fit in a `u32`. The kernel is launched
/// unchecked, so these host-side checks are what keeps a bad call from reading
/// past the end of the device buffers.
pub fn matvec_q8_run<R: Runtime>(
    client: &ComputeClient<R>,
    wd: &[f32],
    wq: &[i32],
    x: &[f32],
    rows: usize,
    k: usize,
) -> Vec<f32> {
    if rows == 0 {
        // Nothing to compute. Skipping the launch also avoids a zero-sized
        // dispatch, which not every backend accepts (CUDA rejects empty grids).
        return Vec::new();
    }
    assert_eq!(k % 32, 0, "k = {k} is not a multiple of the block size 32");
    let nb = k / 32;
    assert!(wd.len() >= rows * nb, "wd must hold rows * (k / 32) block scales");
    assert!(wq.len() >= rows * k, "wq must hold rows * k quantized weights");
    assert!(x.len() >= k, "x must hold k activations");

    let block = 64u32;
    // Checked conversion: a plain `as u32` would silently wrap and under-launch.
    let grid = u32::try_from(rows)
        .expect("rows must fit in u32 to size the launch grid")
        .div_ceil(block);

    let wdh = client.create_from_slice(f32::as_bytes(wd));
    let wqh = client.create_from_slice(i32::as_bytes(wq));
    let xh = client.create_from_slice(f32::as_bytes(x));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; rows]));
    // SAFETY: the kernel is declared `unchecked` (assumed: no device-side bounds
    // checks). For row < rows and i < k it reads wd[row * (k / 32) + i / 32],
    // wq[row * k + i] and x[i], all inside their buffers by the asserts above, and
    // it writes out[row] only when row < out.len() == rows, which covers the
    // rounded-up grid.
    unsafe {
        matvec_q8::launch_unchecked::<f32, R>(
            client,
            Grid::Static(grid, 1, 1),
            Block::new_1d(block),
            ArrayArg::from_raw_parts(wdh.clone(), wd.len()),
            ArrayArg::from_raw_parts(wqh.clone(), wq.len()),
            ArrayArg::from_raw_parts(xh.clone(), x.len()),
            ArrayArg::from_raw_parts(oh.clone(), rows),
            k,
        );
    }
    let bytes = client.read_one_unchecked(oh);
    f32::from_bytes(&bytes).to_vec()
}
/// Benchmarks [`matvec_q8`] and returns the mean wall-clock time per launch in
/// milliseconds over `iters` launches on buffers uploaded once.
///
/// Three untimed warm-up launches run first, followed by a read-back of the output
/// to wait for them. The timed region covers the `iters` launches plus one final
/// read-back of the `rows` outputs. That read-back serves as the synchronization
/// point, so the copy is included in the average.
///
/// Inputs use the kernel's layout: `wd` holds `rows * (k / 32)` block scales,
/// `wq` holds `rows * k` quantized weights (one per `i32`), and `x` holds `k`
/// activations.
///
/// Returns `f64::NAN` without touching the device when `rows` or `iters` is 0,
/// because there is nothing to time.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 32, if any slice is shorter than the layout
/// above requires, or if `rows` does not fit in a `u32`. The kernel is launched
/// unchecked, so these checks guard against out-of-bounds device reads.
pub fn matvec_q8_bench<R: Runtime>(
    client: &ComputeClient<R>,
    wd: &[f32],
    wq: &[i32],
    x: &[f32],
    rows: usize,
    k: usize,
    iters: usize,
) -> f64 {
    if rows == 0 || iters == 0 {
        // Previously this divided by zero (or launched an empty grid); report
        // "no measurement" explicitly instead.
        return f64::NAN;
    }
    assert_eq!(k % 32, 0, "k = {k} is not a multiple of the block size 32");
    let nb = k / 32;
    assert!(wd.len() >= rows * nb, "wd must hold rows * (k / 32) block scales");
    assert!(wq.len() >= rows * k, "wq must hold rows * k quantized weights");
    assert!(x.len() >= k, "x must hold k activations");

    let block = 64u32;
    let grid = u32::try_from(rows)
        .expect("rows must fit in u32 to size the launch grid")
        .div_ceil(block);

    let wdh = client.create_from_slice(f32::as_bytes(wd));
    let wqh = client.create_from_slice(i32::as_bytes(wq));
    let xh = client.create_from_slice(f32::as_bytes(x));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; rows]));
    let launch = |c: &ComputeClient<R>| {
        // SAFETY: the kernel is declared `unchecked` (assumed: no device-side
        // bounds checks). For row < rows and i < k it reads
        // wd[row * (k / 32) + i / 32], wq[row * k + i] and x[i], all inside their
        // buffers by the asserts above, and writes out[row] only when
        // row < out.len() == rows.
        unsafe {
            matvec_q8::launch_unchecked::<f32, R>(
                c,
                Grid::Static(grid, 1, 1),
                Block::new_1d(block),
                ArrayArg::from_raw_parts(wdh.clone(), wd.len()),
                ArrayArg::from_raw_parts(wqh.clone(), wq.len()),
                ArrayArg::from_raw_parts(xh.clone(), x.len()),
                ArrayArg::from_raw_parts(oh.clone(), rows),
                k,
            );
        }
    };

    for _ in 0..3 {
        launch(client);
    }
    // Drain the warm-up work before starting the clock (assumed: the read blocks
    // until queued launches finish).
    let _ = client.read_one_unchecked(oh.clone());
    let start = std::time::Instant::now();
    for _ in 0..iters {
        launch(client);
    }
    let _ = client.read_one_unchecked(oh);
    start.elapsed().as_secs_f64() * 1e3 / iters as f64
}
/// CPU reference for [`matvec_q8`]: returns, for each of `rows` rows,
/// `Σ_{col < k} wd[row * (k / 32) + col / 32] * wq[row * k + col] * x[col]`,
/// accumulated in `f32` in increasing `col` order.
///
/// Panics (via slice indexing) if any input is too short for `rows` and `k`.
pub fn matvec_q8_ref(wd: &[f32], wq: &[i32], x: &[f32], rows: usize, k: usize) -> Vec<f32> {
    let blocks_per_row = k / 32;
    let mut out = Vec::with_capacity(rows);
    for r in 0..rows {
        let row_dot = (0..k).fold(0.0f32, |sum, col| {
            let scale = wd[r * blocks_per_row + col / 32];
            let weight = wq[r * k + col] as f32;
            sum + scale * weight * x[col]
        });
        out.push(row_dot);
    }
    out
}
/// Q4_K super-block size: the number of weights that share one `d` / `dmin` pair.
///
/// `matvec_q4k`, `matvec_q4k_ref` and `gen_q4k` hard-code the same value
/// (`k / 256`, `b * 256`) rather than reading this constant.
pub const QK_K: usize = 256;
/// Returns byte `i` of a byte stream packed four per `u32`, starting at word `base`
/// of `a`.
///
/// Bytes are taken from the least-significant end of each word first: byte 0 is
/// bits 0-7 of `a[base]`, and byte 4 is bits 0-7 of `a[base + 1]`.
#[device]
fn byte_at(a: &Array<u32>, base: usize, i: usize) -> u32 {
    (a[base + i / 4] >> ((8 * (i % 4)) as u32)) & 255
}
/// Decodes the 6-bit scale of sub-block `j` from the 12 packed scale bytes that
/// start at word `scbase` of `wsc`.
///
/// * `j < 4`: the low 6 bits of byte `j`.
/// * `j >= 4`: the low nibble of byte `j + 4` as bits 0-3, plus the top two bits of
///   byte `j - 4` as bits 4-5.
///
/// `matvec_q4k` calls this with `j` in `0..8`. NOTE(review): the bit layout appears
/// to match the scale half of llama.cpp's `get_scale_min_k4`; confirm before changing.
#[device]
fn q4k_sc(wsc: &Array<u32>, scbase: usize, j: usize) -> u32 {
    // Start with the `j < 4` encoding. For `j >= 4` this byte is read and then
    // discarded by the overwrite below.
    let mut r = byte_at(wsc, scbase, j) & 63;
    if j >= 4 {
        r = (byte_at(wsc, scbase, j + 4) & 15) | ((byte_at(wsc, scbase, j - 4) >> 6) << 4);
    }
    r
}
/// Decodes the 6-bit min of sub-block `j` from the 12 packed scale bytes that start
/// at word `scbase` of `wsc`.
///
/// * `j < 4`: the low 6 bits of byte `j + 4`.
/// * `j >= 4`: the high nibble of byte `j + 4` as bits 0-3, plus the top two bits of
///   byte `j` as bits 4-5.
///
/// `matvec_q4k` calls this with `j` in `0..8`. NOTE(review): the bit layout appears
/// to match the min half of llama.cpp's `get_scale_min_k4`; confirm before changing.
#[device]
fn q4k_m(wsc: &Array<u32>, scbase: usize, j: usize) -> u32 {
    // Start with the `j < 4` encoding, then overwrite it for the upper four sub-blocks.
    let mut r = byte_at(wsc, scbase, j + 4) & 63;
    if j >= 4 {
        r = (byte_at(wsc, scbase, j + 4) >> 4) | ((byte_at(wsc, scbase, j) >> 6) << 4);
    }
    r
}
/// Dequantizing matrix-vector product for 4-bit "Q4_K"-style weights
/// (256-element super-blocks, cf. `QK_K`).
///
/// One invocation computes one output row (`ABSOLUTE_POS`); invocations with
/// `row >= out.len()` do nothing. Row `row` is made of `k / 256` consecutive
/// super-blocks `blk = row * (k / 256) + b`. Each super-block is described by:
///
/// * words `blk * 32 .. blk * 32 + 32` of `wqs`: 128 bytes of packed 4-bit quants,
///   read via `byte_at`;
/// * words `blk * 3 .. blk * 3 + 3` of `wsc`: 12 bytes of packed 6-bit scales and
///   mins, decoded by `q4k_sc` / `q4k_m` for sub-blocks `0..8`;
/// * `wd[blk]` and `wdm[blk]`: the super-block scale `d` and min scale `dmin`.
///
/// Super-block `b` covers `x[b * 256 .. b * 256 + 256]` in four groups of 64. In
/// group `g`, quant byte `g * 32 + qi` supplies two weights:
/// * its low nibble multiplies `x[b * 256 + g * 64 + qi]` using sub-block `2g`;
/// * its high nibble multiplies `x[b * 256 + g * 64 + 32 + qi]` using sub-block
///   `2g + 1`.
///
/// Each weight is `d * sc * q - dmin * m`.
///
/// `k` is fixed at compile time and assumed to be a multiple of 256 (`k / 256`
/// truncates); that is not checked here. NOTE(review): the kernel is declared
/// `unchecked`, presumably without bounds checks, so the host must size every array.
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
pub fn matvec_q4k<F: Float>(
    wqs: &Array<u32>,
    wsc: &Array<u32>,
    wd: &Array<F>,
    wdm: &Array<F>,
    x: &Array<F>,
    out: &mut Array<F>,
    #[comptime] k: usize,
) {
    let row = ABSOLUTE_POS;
    // Guard the tail invocations of a rounded-up grid.
    if row < out.len() {
        let nb = k / 256;
        let mut acc = F::new(0.0);
        for b in 0..nb {
            let blk = row * nb + b;
            // 32 words = 128 bytes of nibbles; 3 words = 12 bytes of scales/mins.
            let qbase = blk * 32;
            let scbase = blk * 3;
            let d = wd[blk];
            let dmin = wdm[blk];
            let xbase = b * 256;
            for g in 0..4 {
                // Each 64-input group uses two consecutive sub-blocks: one for the
                // low nibbles, one for the high nibbles.
                let is = g * 2;
                let d1 = d * F::cast_from(q4k_sc(wsc, scbase, is));
                let mm1 = dmin * F::cast_from(q4k_m(wsc, scbase, is));
                let d2 = d * F::cast_from(q4k_sc(wsc, scbase, is + 1));
                let mm2 = dmin * F::cast_from(q4k_m(wsc, scbase, is + 1));
                for qi in 0..32 {
                    let qb = byte_at(wqs, qbase, g * 32 + qi);
                    let wlo = d1 * F::cast_from(qb & 15) - mm1;
                    acc += wlo * x[xbase + g * 64 + qi];
                    let whi = d2 * F::cast_from(qb >> 4) - mm2;
                    acc += whi * x[xbase + g * 64 + 32 + qi];
                }
            }
        }
        out[row] = acc;
    }
}
/// Uploads a Q4_K-style matrix and input vector, runs [`matvec_q4k`] once, and
/// returns the `rows` results.
///
/// With `nblk = rows * (k / 256)` super-blocks (row-major, `k / 256` per row):
/// * `wqs` — `nblk * 32` words of packed 4-bit quants;
/// * `wsc` — `nblk * 3` words of packed 6-bit scales and mins;
/// * `wd`, `wdm` — `nblk` super-block scales and min scales;
/// * `x` — `k` activations.
///
/// Fresh device buffers are created on every call. When `rows` is 0 the function
/// returns an empty vector without touching the device.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 256, if any slice is shorter than the layout
/// above requires, or if `rows` does not fit in a `u32`. The kernel is launched
/// unchecked, so these host-side checks are what keeps a bad call from reading
/// past the end of the device buffers.
pub fn matvec_q4k_run<R: Runtime>(
    client: &ComputeClient<R>,
    wqs: &[u32],
    wsc: &[u32],
    wd: &[f32],
    wdm: &[f32],
    x: &[f32],
    rows: usize,
    k: usize,
) -> Vec<f32> {
    if rows == 0 {
        // Nothing to compute. Skipping the launch also avoids a zero-sized
        // dispatch, which not every backend accepts (CUDA rejects empty grids).
        return Vec::new();
    }
    assert_eq!(k % 256, 0, "k = {k} is not a multiple of the super-block size 256");
    let nblk = rows * (k / 256);
    assert!(wqs.len() >= nblk * 32, "wqs must hold rows * (k / 256) * 32 words");
    assert!(wsc.len() >= nblk * 3, "wsc must hold rows * (k / 256) * 3 words");
    assert!(wd.len() >= nblk, "wd must hold rows * (k / 256) scales");
    assert!(wdm.len() >= nblk, "wdm must hold rows * (k / 256) min scales");
    assert!(x.len() >= k, "x must hold k activations");

    let block = 64u32;
    // Checked conversion: a plain `as u32` would silently wrap and under-launch.
    let grid = u32::try_from(rows)
        .expect("rows must fit in u32 to size the launch grid")
        .div_ceil(block);

    let qh = client.create_from_slice(u32::as_bytes(wqs));
    let sh = client.create_from_slice(u32::as_bytes(wsc));
    let dh = client.create_from_slice(f32::as_bytes(wd));
    let mh = client.create_from_slice(f32::as_bytes(wdm));
    let xh = client.create_from_slice(f32::as_bytes(x));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; rows]));
    // SAFETY: the kernel is declared `unchecked` (assumed: no device-side bounds
    // checks). For row < rows it reads `wqs` words below nblk * 32, `wsc` words
    // below nblk * 3, `wd`/`wdm` entries below nblk, and x[..k]. All of these are
    // inside their buffers by the asserts above. It writes out[row] only when
    // row < out.len() == rows.
    unsafe {
        matvec_q4k::launch_unchecked::<f32, R>(
            client,
            Grid::Static(grid, 1, 1),
            Block::new_1d(block),
            ArrayArg::from_raw_parts(qh.clone(), wqs.len()),
            ArrayArg::from_raw_parts(sh.clone(), wsc.len()),
            ArrayArg::from_raw_parts(dh.clone(), wd.len()),
            ArrayArg::from_raw_parts(mh.clone(), wdm.len()),
            ArrayArg::from_raw_parts(xh.clone(), x.len()),
            ArrayArg::from_raw_parts(oh.clone(), rows),
            k,
        );
    }
    f32::from_bytes(&client.read_one_unchecked(oh)).to_vec()
}
/// Benchmarks [`matvec_q4k`] and returns the mean wall-clock time per launch in
/// milliseconds over `iters` launches on buffers uploaded once.
///
/// Three untimed warm-up launches run first, followed by a read-back of the output
/// to wait for them. The timed region covers the `iters` launches plus one final
/// read-back of the `rows` outputs. That read-back serves as the synchronization
/// point, so the copy is included in the average.
///
/// Inputs use the kernel's layout with `nblk = rows * (k / 256)`: `wqs` holds
/// `nblk * 32` words, `wsc` holds `nblk * 3` words, `wd` and `wdm` hold `nblk`
/// entries each, and `x` holds `k` activations.
///
/// Returns `f64::NAN` without touching the device when `rows` or `iters` is 0,
/// because there is nothing to time.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 256, if any slice is shorter than the layout
/// above requires, or if `rows` does not fit in a `u32`. The kernel is launched
/// unchecked, so these checks guard against out-of-bounds device reads.
pub fn matvec_q4k_bench<R: Runtime>(
    client: &ComputeClient<R>,
    wqs: &[u32],
    wsc: &[u32],
    wd: &[f32],
    wdm: &[f32],
    x: &[f32],
    rows: usize,
    k: usize,
    iters: usize,
) -> f64 {
    if rows == 0 || iters == 0 {
        // Previously this divided by zero (or launched an empty grid); report
        // "no measurement" explicitly instead.
        return f64::NAN;
    }
    assert_eq!(k % 256, 0, "k = {k} is not a multiple of the super-block size 256");
    let nblk = rows * (k / 256);
    assert!(wqs.len() >= nblk * 32, "wqs must hold rows * (k / 256) * 32 words");
    assert!(wsc.len() >= nblk * 3, "wsc must hold rows * (k / 256) * 3 words");
    assert!(wd.len() >= nblk, "wd must hold rows * (k / 256) scales");
    assert!(wdm.len() >= nblk, "wdm must hold rows * (k / 256) min scales");
    assert!(x.len() >= k, "x must hold k activations");

    let block = 64u32;
    let grid = u32::try_from(rows)
        .expect("rows must fit in u32 to size the launch grid")
        .div_ceil(block);

    let qh = client.create_from_slice(u32::as_bytes(wqs));
    let sh = client.create_from_slice(u32::as_bytes(wsc));
    let dh = client.create_from_slice(f32::as_bytes(wd));
    let mh = client.create_from_slice(f32::as_bytes(wdm));
    let xh = client.create_from_slice(f32::as_bytes(x));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; rows]));
    let launch = |c: &ComputeClient<R>| {
        // SAFETY: the kernel is declared `unchecked` (assumed: no device-side
        // bounds checks). The asserts above keep every read of `wqs`, `wsc`, `wd`,
        // `wdm` and x[..k] for row < rows inside its buffer. The kernel writes
        // out[row] only when row < out.len() == rows.
        unsafe {
            matvec_q4k::launch_unchecked::<f32, R>(
                c,
                Grid::Static(grid, 1, 1),
                Block::new_1d(block),
                ArrayArg::from_raw_parts(qh.clone(), wqs.len()),
                ArrayArg::from_raw_parts(sh.clone(), wsc.len()),
                ArrayArg::from_raw_parts(dh.clone(), wd.len()),
                ArrayArg::from_raw_parts(mh.clone(), wdm.len()),
                ArrayArg::from_raw_parts(xh.clone(), x.len()),
                ArrayArg::from_raw_parts(oh.clone(), rows),
                k,
            );
        }
    };

    for _ in 0..3 {
        launch(client);
    }
    // Drain the warm-up work before starting the clock (assumed: the read blocks
    // until queued launches finish).
    let _ = client.read_one_unchecked(oh.clone());
    let start = std::time::Instant::now();
    for _ in 0..iters {
        launch(client);
    }
    let _ = client.read_one_unchecked(oh);
    start.elapsed().as_secs_f64() * 1e3 / iters as f64
}
/// CPU mirror of the device `byte_at`: returns byte `i` of a byte stream packed
/// four per `u32` (least-significant byte first), starting at word `base` of `a`.
#[inline]
fn cpu_byte(a: &[u32], base: usize, i: usize) -> u32 {
    let word = a[base + i / 4];
    let shift = (i % 4) * 8;
    (word >> shift) & 0xFF
}
/// CPU mirror of the device `q4k_sc`: the 6-bit scale of sub-block `j` in the
/// 12 packed scale bytes that start at word `scbase`.
#[inline]
fn cpu_sc(wsc: &[u32], scbase: usize, j: usize) -> u32 {
    let byte = |n: usize| cpu_byte(wsc, scbase, n);
    match j {
        // Sub-blocks 0..4 keep their scale in the low 6 bits of byte j.
        0..=3 => byte(j) & 0x3F,
        // Sub-blocks 4..8: low nibble of byte j+4, top two bits of byte j-4 on top.
        _ => {
            let low = byte(j + 4) & 0x0F;
            let high = byte(j - 4) >> 6;
            low | (high << 4)
        }
    }
}
/// CPU mirror of the device `q4k_m`: the 6-bit min of sub-block `j` in the
/// 12 packed scale bytes that start at word `scbase`.
#[inline]
fn cpu_m(wsc: &[u32], scbase: usize, j: usize) -> u32 {
    let byte = |n: usize| cpu_byte(wsc, scbase, n);
    // Both encodings start from byte j+4.
    let upper = byte(j + 4);
    if j >= 4 {
        // High nibble of byte j+4, top two bits of byte j as bits 4-5.
        (upper >> 4) | ((byte(j) >> 6) << 4)
    } else {
        upper & 0x3F
    }
}
/// CPU reference for [`matvec_q4k`] using the same super-block layout and the same
/// `f32` accumulation order.
///
/// The accumulation order is: super-blocks in order; within each, four groups of
/// 64 inputs; within each group, the low-nibble term then the high-nibble term for
/// each of the 32 quant bytes. Each weight is `d * sc * q - dmin * m`.
///
/// Panics (via slice indexing) if any input is too short for `rows` and `k`.
pub fn matvec_q4k_ref(
    wqs: &[u32],
    wsc: &[u32],
    wd: &[f32],
    wdm: &[f32],
    x: &[f32],
    rows: usize,
    k: usize,
) -> Vec<f32> {
    let blocks_per_row = k / 256;
    let mut out = Vec::with_capacity(rows);
    for row in 0..rows {
        let mut sum = 0.0f32;
        for blk_in_row in 0..blocks_per_row {
            let blk = row * blocks_per_row + blk_in_row;
            let q_words = blk * 32;
            let sc_words = blk * 3;
            let scale = wd[blk];
            let min_scale = wdm[blk];
            let x_off = blk_in_row * 256;
            for group in 0..4 {
                let sub = 2 * group;
                let lo_scale = scale * cpu_sc(wsc, sc_words, sub) as f32;
                let lo_min = min_scale * cpu_m(wsc, sc_words, sub) as f32;
                let hi_scale = scale * cpu_sc(wsc, sc_words, sub + 1) as f32;
                let hi_min = min_scale * cpu_m(wsc, sc_words, sub + 1) as f32;
                let x_lo = x_off + 64 * group;
                let x_hi = x_lo + 32;
                for lane in 0..32 {
                    let packed = cpu_byte(wqs, q_words, 32 * group + lane);
                    let lo = (packed & 0x0F) as f32;
                    let hi = (packed >> 4) as f32;
                    sum += (lo_scale * lo - lo_min) * x[x_lo + lane];
                    sum += (hi_scale * hi - hi_min) * x[x_hi + lane];
                }
            }
        }
        out.push(sum);
    }
    out
}
/// Generates deterministic pseudo-random Q4_K test data for a `rows` x `k` problem.
///
/// Returns `(wqs, wsc, wd, wdm, x)` sized for [`matvec_q4k`], with
/// `nblk = rows * (k / 256)`:
/// * `nblk * 32` quant words and `nblk * 3` scale words (raw random bits);
/// * `nblk` scales `d` in roughly `[0.002, 0.052]` and `nblk` min scales `dmin` in
///   roughly `[0, 0.025]`, both rounded through `f16`;
/// * `k` activations in `[-1.0, 0.999]`.
///
/// The seed is fixed, so the same `(rows, k)` always yields the same data. Values
/// are drawn in the order `wqs`, `wsc`, `wd`, `wdm`, `x`.
pub fn gen_q4k(rows: usize, k: usize) -> (Vec<u32>, Vec<u32>, Vec<f32>, Vec<f32>, Vec<f32>) {
    let total_blocks = rows * (k / 256);
    // xorshift64 with shifts (13, 7, 17), seeded with the 64-bit golden-ratio constant.
    let mut state = 0x9E37_79B9_7F4A_7C15u64;
    let mut rng = || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        state
    };
    let quants: Vec<u32> = (0..total_blocks * 32).map(|_| rng() as u32).collect();
    let scales: Vec<u32> = (0..total_blocks * 3).map(|_| rng() as u32).collect();
    let d: Vec<f32> = (0..total_blocks)
        .map(|_| {
            let raw = (rng() % 1000) as f32 / 20000.0 + 0.002;
            half::f16::from_f32(raw).to_f32()
        })
        .collect();
    let dmin: Vec<f32> = (0..total_blocks)
        .map(|_| {
            let raw = (rng() % 1000) as f32 / 40000.0;
            half::f16::from_f32(raw).to_f32()
        })
        .collect();
    let x: Vec<f32> = (0..k)
        .map(|_| (rng() % 2000) as f32 / 1000.0 - 1.0)
        .collect();
    (quants, scales, d, dmin, x)
}
/// Q8 matrix-vector product that reads weights and activations as 4-lane `i32`
/// vectors and accumulates one `dot` per vector.
///
/// One invocation computes one output row (`ABSOLUTE_POS`); invocations with
/// `row >= out.len()` do nothing. With `ng = k / 4` vectors per row:
/// `out[row] = Σ_g wd[row * (k / 32) + g / 8] * dot(wq[row * ng + g], xq[g])`.
/// Eight 4-lane vectors make up one 32-element scale block.
///
/// * `wq` — `ng` vectors per row, row-major.
/// * `xq` — `ng` vectors of quantized activations.
/// * `wd` — `k / 32` scales per row.
///
/// No scale is applied to `xq`, so the result is in units of the activation
/// quantization. NOTE(review): presumably the caller rescales; confirm.
///
/// Assumes `dot` is the lane-wise multiply-sum of the two vectors. Despite the
/// name, each lane here is a full `i32`; whether `dot` lowers to a hardware dp4a
/// instruction on any backend is not visible from this code.
///
/// `k` is fixed at compile time and assumed to be a multiple of 32; that is not
/// checked here, and the kernel is declared `unchecked`.
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
pub fn matvec_q8_dp4a<F: Float>(
    wq: &Array<Vector<i32, Const<4>>>, xq: &Array<Vector<i32, Const<4>>>, wd: &Array<F>, out: &mut Array<F>,
    #[comptime] k: usize,
) {
    let row = ABSOLUTE_POS;
    // Guard the tail invocations of a rounded-up grid.
    if row < out.len() {
        let ng = k / 4;
        let nb = k / 32;
        let wbase = row * ng;
        let dbase = row * nb;
        let mut acc = F::new(0.0);
        for g in 0..ng {
            // Integer dot product of four lanes, converted to F and scaled per
            // 32-element block (8 groups of 4).
            let dp = wq[wbase + g].dot(xq[g]);
            acc += wd[dbase + g / 8] * F::cast_from(dp);
        }
        out[row] = acc;
    }
}
/// Runs [`matvec_q8_dp4a`] once to collect its output, then benchmarks it.
///
/// Layout:
/// * `wq` — `rows * k` quantized weights;
/// * `xq` — `k` quantized activations;
/// * `wd` — `rows * (k / 32)` block scales.
///
/// `wq` and `xq` are stored as plain `i32`s and are handed to the kernel as 4-lane
/// vectors (`rows * k / 4` and `k / 4` of them).
///
/// Returns `(out, ms)`:
/// * `out` — the `rows` results of the first launch;
/// * `ms` — the mean wall-clock milliseconds per launch over `bench_iters` timed
///   launches, measured after three warm-up launches. The timed region ends with a
///   read-back of the output, so that copy is included in the average.
///
/// `ms` is `f64::NAN` when there is nothing to time. When `bench_iters` is 0, no
/// warm-up or timed launches are issued. When `rows` is 0, `out` is empty and the
/// device is not touched.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 32, if `wq`, `xq` or `wd` is shorter than the
/// layout requires, or if `rows` does not fit in a `u32`. The kernel's array
/// lengths are derived from `rows` and `k` rather than from the slices, and the
/// launch is unchecked, so an undersized slice would otherwise be read out of
/// bounds on the device.
pub fn matvec_q8_dp4a_run<R: Runtime>(
    client: &ComputeClient<R>,
    wq: &[i32],
    xq: &[i32],
    wd: &[f32],
    rows: usize,
    k: usize,
    bench_iters: usize,
) -> (Vec<f32>, f64) {
    if rows == 0 {
        // Nothing to compute or time. This also avoids a zero-sized dispatch,
        // which not every backend accepts (CUDA rejects empty grids).
        return (Vec::new(), f64::NAN);
    }
    // A multiple of 32 is also a multiple of the 4-lane vector width.
    assert_eq!(k % 32, 0, "k = {k} is not a multiple of the block size 32");
    assert!(wq.len() >= rows * k, "wq must hold rows * k quantized weights");
    assert!(xq.len() >= k, "xq must hold k quantized activations");
    assert!(wd.len() >= rows * (k / 32), "wd must hold rows * (k / 32) block scales");

    let block = 64u32;
    let grid = u32::try_from(rows)
        .expect("rows must fit in u32 to size the launch grid")
        .div_ceil(block);
    let ng = k / 4;

    let wqh = client.create_from_slice(i32::as_bytes(wq));
    let xqh = client.create_from_slice(i32::as_bytes(xq));
    let wdh = client.create_from_slice(f32::as_bytes(wd));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; rows]));
    let launch = |c: &ComputeClient<R>| {
        // SAFETY: the kernel is declared `unchecked` (assumed: no device-side
        // bounds checks). The declared lengths rows * ng and ng vectors
        // correspond to rows * k and k `i32`s, which the asserts above guarantee
        // fit in the `wq` and `xq` buffers. Reads of `wd` stay below
        // rows * (k / 32) <= wd.len(). The kernel writes out[row] only when
        // row < out.len() == rows.
        unsafe {
            matvec_q8_dp4a::launch_unchecked::<f32, R>(
                c,
                Grid::Static(grid, 1, 1),
                Block::new_1d(block),
                ArrayArg::from_raw_parts(wqh.clone(), rows * ng),
                ArrayArg::from_raw_parts(xqh.clone(), ng),
                ArrayArg::from_raw_parts(wdh.clone(), wd.len()),
                ArrayArg::from_raw_parts(oh.clone(), rows),
                k,
            );
        }
    };

    launch(client);
    let bytes = client.read_one_unchecked(oh.clone());
    let out = f32::from_bytes(&bytes).to_vec();
    if bench_iters == 0 {
        // Previously this fell through to a division by zero.
        return (out, f64::NAN);
    }

    for _ in 0..3 {
        launch(client);
    }
    // Drain the warm-up work before starting the clock (assumed: the read blocks
    // until queued launches finish).
    let _ = client.read_one_unchecked(oh.clone());
    let start = std::time::Instant::now();
    for _ in 0..bench_iters {
        launch(client);
    }
    let _ = client.read_one_unchecked(oh);
    let ms = start.elapsed().as_secs_f64() * 1e3 / bench_iters as f64;
    (out, ms)
}
/// CPU reference for [`matvec_q8_dp4a`] that works element by element:
/// `out[row] = Σ_{col < k} wd[row * (k / 32) + col / 32] * (wq[row * k + col] * xq[col])`.
///
/// The integer product is formed in `i32` (it panics on overflow in debug builds)
/// and converted to `f32` before scaling. The summation order differs from the
/// kernel's per-4-lane grouping, so float results can differ in the last bits.
pub fn matvec_q8_dp4a_ref(wq: &[i32], xq: &[i32], wd: &[f32], rows: usize, k: usize) -> Vec<f32> {
    let blocks_per_row = k / 32;
    let mut out = vec![0.0f32; rows];
    for (row, slot) in out.iter_mut().enumerate() {
        for col in 0..k {
            let scale = wd[row * blocks_per_row + col / 32];
            let prod = wq[row * k + col] * xq[col];
            *slot += scale * prod as f32;
        }
    }
    out
}