use half::f16;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Number of weights covered by one Q4_0 or Q8_0 block.
pub const QK: usize = 32;
/// Bytes in one Q4_0 block: a 2-byte f16 scale followed by `QK` 4-bit quants, two per byte.
pub const Q4_0_BLOCK_BYTES: usize = 2 + QK / 2;
/// Bytes in one Q8_0 block: a 2-byte f16 scale followed by `QK` signed 8-bit quants.
pub const Q8_0_BLOCK_BYTES: usize = 2 + QK;
/// Quantizes `QK` f32 values into one Q4_0 block.
///
/// Bytes 0..2 of the result hold the little-endian f16 scale `d`. Byte `2 + j` packs
/// weight `j` in its low nibble and weight `j + QK / 2` in its high nibble. `d` is the
/// signed value of largest magnitude divided by -8, so that value encodes as nibble 0.
/// Every weight `w` is stored as `nibble(w / d)`.
pub fn quantize_q4_0_block(x: &[f32]) -> [u8; Q4_0_BLOCK_BYTES] {
    debug_assert_eq!(x.len(), QK);
    // Signed element of largest magnitude (the first one wins ties; 0.0 if all are zero).
    let (_, signed_peak) = x.iter().fold((0.0f32, 0.0f32), |(peak_abs, peak), &v| {
        if v.abs() > peak_abs {
            (v.abs(), v)
        } else {
            (peak_abs, peak)
        }
    });
    let scale = signed_peak / -8.0;
    // A zero scale gets a zero inverse instead of dividing by zero.
    let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };

    let mut block = [0u8; Q4_0_BLOCK_BYTES];
    block[..2].copy_from_slice(&f16::from_f32(scale).to_le_bytes());
    let (first_half, second_half) = (&x[..QK / 2], &x[QK / 2..QK]);
    for ((slot, &a), &b) in block[2..].iter_mut().zip(first_half).zip(second_half) {
        *slot = nibble(a * inv_scale) | (nibble(b * inv_scale) << 4);
    }
    block
}
/// Maps a scaled weight to a 4-bit code.
///
/// Adds 8.5 and truncates toward zero (a saturating cast; NaN becomes 0), then clamps
/// the result to `0..=15`.
#[inline]
fn nibble(scaled: f32) -> u8 {
    match (scaled + 8.5) as i32 {
        below if below < 0 => 0,
        above if above > 15 => 15,
        code => code as u8,
    }
}
/// Expands one Q4_0 block back into `QK` f32 values in `out`.
///
/// Each nibble is re-centred from `0..=15` to `-8..=7` and multiplied by the block's f16
/// scale. Low nibbles fill `out[..QK / 2]` and high nibbles fill `out[QK / 2..]`.
pub fn dequantize_q4_0_block(block: &[u8], out: &mut [f32]) {
    debug_assert_eq!(block.len(), Q4_0_BLOCK_BYTES);
    debug_assert_eq!(out.len(), QK);
    let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
    let half = QK / 2;
    for (j, &packed) in block[2..2 + half].iter().enumerate() {
        let low = i32::from(packed & 0x0f) - 8;
        let high = i32::from(packed >> 4) - 8;
        out[j] = low as f32 * scale;
        out[j + half] = high as f32 * scale;
    }
}
/// Dot product of one Q4_0 block with `QK` f32 activations.
///
/// Uses the NEON kernel on aarch64 targets compiled with NEON, and the portable scalar
/// kernel everywhere else.
///
/// # Panics
///
/// Panics if `block` is shorter than [`Q4_0_BLOCK_BYTES`] or `x` is shorter than [`QK`].
/// Debug builds also require the lengths to match exactly.
#[inline]
pub fn dot_q4_0_block_f32(block: &[u8], x: &[f32]) -> f32 {
    debug_assert_eq!(block.len(), Q4_0_BLOCK_BYTES);
    debug_assert_eq!(x.len(), QK);
    // This check also runs in release builds. The NEON kernel reads through raw pointers,
    // so without it a short slice would be an out-of-bounds read reachable from safe code.
    assert!(
        block.len() >= Q4_0_BLOCK_BYTES && x.len() >= QK,
        "Q4_0 block dot needs {} block bytes and {} activations, got {} and {}",
        Q4_0_BLOCK_BYTES,
        QK,
        block.len(),
        x.len()
    );
    // SAFETY: both lengths were checked above, and the cfg guarantees NEON is enabled.
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    return unsafe { dot_q4_0_block_neon(block, x) };
    #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
    dot_q4_0_block_scalar(block, x)
}
/// Portable Q4_0 block dot product: `scale * sum(q_lo[j] * x[j] + q_hi[j] * x[j + QK/2])`.
#[inline(always)]
// Not called on targets that dispatch to a SIMD kernel; kept as the portable fallback.
#[allow(dead_code)]
fn dot_q4_0_block_scalar(block: &[u8], x: &[f32]) -> f32 {
    let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
    let half = QK / 2;
    let sum = (0..half).fold(0.0f32, |acc, j| {
        let packed = block[2 + j];
        let low = (i32::from(packed & 0x0f) - 8) as f32;
        let high = (i32::from(packed >> 4) - 8) as f32;
        acc + (low * x[j] + high * x[j + half])
    });
    sum * scale
}
/// NEON kernel for `dot_q4_0_block_f32`: dot product of one Q4_0 block with `QK` activations.
///
/// Steps:
/// - Split the 16 packed bytes into low nibbles (weights 0..16) and high nibbles
///   (weights 16..32), and re-centre each by subtracting 8.
/// - Widen them i8 -> i16 -> i32 -> f32 in groups of four.
/// - Multiply-accumulate them against `x[0..32]`.
/// - Reduce the four-lane accumulator horizontally and scale it by the block's f16 scale.
///
/// # Safety
///
/// - NEON must be available on the executing CPU.
/// - `block` must hold at least `Q4_0_BLOCK_BYTES` (18) bytes and `x` at least `QK` (32)
///   floats. The 16 packed bytes at offset 2 and all 32 activations are read through raw
///   pointers without bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q4_0_block_neon(block: &[u8], x: &[f32]) -> f32 {
let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
// The 16 packed quant bytes follow the 2-byte f16 scale.
let packed_ptr = block.as_ptr().add(2);
let packed = vld1q_u8(packed_ptr);
// Low nibbles are weights 0..16 and high nibbles are weights 16..32 (packing order of quantize_q4_0_block).
let lo_u8 = vandq_u8(packed, vdupq_n_u8(0x0F));
let hi_u8 = vshrq_n_u8(packed, 4);
// Re-centre 0..15 to -8..7: the wrapping u8 subtraction reinterpreted as i8 yields the signed value.
let eight = vdupq_n_u8(8);
let lo_i8 = vreinterpretq_s8_u8(vsubq_u8(lo_u8, eight));
let hi_i8 = vreinterpretq_s8_u8(vsubq_u8(hi_u8, eight));
// Widen one 8-lane half (chosen by `$half`) of an i8x16 to i16, then to two f32x4 vectors.
macro_rules! to_f32x4 {
($i8vec:expr, $half:ident) => {{
let i16v = $half($i8vec);
let lo32 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16v)));
let hi32 = vcvtq_f32_s32(vmovl_high_s16(i16v));
(lo32, hi32)
}};
}
// lo_f32_0..3 hold weights 0..16 in groups of four; hi_f32_0..3 hold weights 16..32.
let (lo_f32_0, lo_f32_1) = to_f32x4!(lo_i8, vmovl_s8_low);
let (lo_f32_2, lo_f32_3) = to_f32x4!(lo_i8, vmovl_s8_high);
let (hi_f32_0, hi_f32_1) = to_f32x4!(hi_i8, vmovl_s8_low);
let (hi_f32_2, hi_f32_3) = to_f32x4!(hi_i8, vmovl_s8_high);
// x0..x7 hold x[0..32] in groups of four, in the same order as the weights above.
let xp = x.as_ptr();
let x0 = vld1q_f32(xp);
let x1 = vld1q_f32(xp.add(4));
let x2 = vld1q_f32(xp.add(8));
let x3 = vld1q_f32(xp.add(12));
let x4 = vld1q_f32(xp.add(16));
let x5 = vld1q_f32(xp.add(20));
let x6 = vld1q_f32(xp.add(24));
let x7 = vld1q_f32(xp.add(28));
// Accumulate all 32 products into one 4-lane vector, then reduce across lanes and apply the scale.
let mut acc = vmulq_f32(lo_f32_0, x0);
acc = vfmaq_f32(acc, lo_f32_1, x1);
acc = vfmaq_f32(acc, lo_f32_2, x2);
acc = vfmaq_f32(acc, lo_f32_3, x3);
acc = vfmaq_f32(acc, hi_f32_0, x4);
acc = vfmaq_f32(acc, hi_f32_1, x5);
acc = vfmaq_f32(acc, hi_f32_2, x6);
acc = vfmaq_f32(acc, hi_f32_3, x7);
vaddvq_f32(acc) * scale
}
/// Sign-extends lanes 0..8 of an `int8x16_t` into an `int16x8_t`.
///
/// # Safety
///
/// NEON must be available on the executing CPU.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn vmovl_s8_low(v: int8x16_t) -> int16x8_t {
    let low_lanes: int8x8_t = vget_low_s8(v);
    vmovl_s8(low_lanes)
}
/// Sign-extends lanes 8..16 of an `int8x16_t` into an `int16x8_t`.
///
/// # Safety
///
/// NEON must be available on the executing CPU.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn vmovl_s8_high(v: int8x16_t) -> int16x8_t {
    vmovl_s8(vget_high_s8(v))
}
/// Dot product of a row of packed Q4_0 blocks with the f32 activations `x`.
///
/// `row_blocks` is read in [`Q4_0_BLOCK_BYTES`]-byte chunks, and any trailing partial chunk
/// is ignored. Block `b` is paired with `x[b * QK..(b + 1) * QK]`, and the block results
/// are summed in order.
///
/// # Panics
///
/// Panics if `x` is too short for the number of blocks in `row_blocks`.
pub fn dot_q4_0_row_f32(row_blocks: &[u8], x: &[f32]) -> f32 {
    debug_assert_eq!(x.len() % QK, 0);
    row_blocks
        .chunks_exact(Q4_0_BLOCK_BYTES)
        .enumerate()
        .fold(0.0f32, |acc, (b, block)| {
            let start = b * QK;
            acc + dot_q4_0_block_f32(block, &x[start..start + QK])
        })
}
pub fn quantize_q4_0_row(w: &[f32]) -> Vec<u8> {
debug_assert_eq!(w.len() % QK, 0);
let mut out = Vec::with_capacity(w.len() / QK * Q4_0_BLOCK_BYTES);
for chunk in w.chunks_exact(QK) {
out.extend_from_slice(&quantize_q4_0_block(chunk));
}
out
}
/// Dot product of one Q8_0 block with `QK` f32 activations.
///
/// Uses the NEON kernel on aarch64 targets compiled with NEON, and the portable scalar
/// kernel everywhere else.
///
/// # Panics
///
/// Panics if `block` is shorter than [`Q8_0_BLOCK_BYTES`] or `x` is shorter than [`QK`].
/// Debug builds also require the lengths to match exactly.
#[inline]
pub fn dot_q8_0_block_f32(block: &[u8], x: &[f32]) -> f32 {
    debug_assert_eq!(block.len(), Q8_0_BLOCK_BYTES);
    debug_assert_eq!(x.len(), QK);
    // This check also runs in release builds. The NEON kernel reads through raw pointers,
    // so without it a short slice would be an out-of-bounds read reachable from safe code.
    assert!(
        block.len() >= Q8_0_BLOCK_BYTES && x.len() >= QK,
        "Q8_0 block dot needs {} block bytes and {} activations, got {} and {}",
        Q8_0_BLOCK_BYTES,
        QK,
        block.len(),
        x.len()
    );
    // SAFETY: both lengths were checked above, and the cfg guarantees NEON is enabled.
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    return unsafe { dot_q8_0_block_neon(block, x) };
    #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
    dot_q8_0_block_scalar(block, x)
}
/// Portable Q8_0 block dot product: `scale * sum(q[j] * x[j])` over the `QK` signed quants.
#[inline(always)]
// Not called on targets that dispatch to a SIMD kernel; kept as the portable fallback.
#[allow(dead_code)]
fn dot_q8_0_block_scalar(block: &[u8], x: &[f32]) -> f32 {
    let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
    let sum = (0..QK).fold(0.0f32, |acc, j| acc + f32::from(block[2 + j] as i8) * x[j]);
    sum * scale
}
/// NEON kernel for `dot_q8_0_block_f32`: dot product of one Q8_0 block with `QK` activations.
///
/// The 32 signed quants are processed in four groups of eight. Each group is:
/// - loaded as an `int8x8_t`,
/// - widened to i16 and then to two `int32x4_t` halves,
/// - converted to f32 and multiply-accumulated against the matching eight activations.
///
/// The accumulator is then reduced across lanes and multiplied by the block's f16 scale.
///
/// # Safety
///
/// - NEON must be available on the executing CPU.
/// - `block` must hold at least `Q8_0_BLOCK_BYTES` (34) bytes and `x` at least `QK` (32)
///   floats. Quant bytes 2..34 and `x[0..32]` are read through raw pointers without
///   bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q8_0_block_neon(block: &[u8], x: &[f32]) -> f32 {
let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
// The signed quants start right after the 2-byte f16 scale.
let q_ptr = block.as_ptr().add(2) as *const i8;
let xp = x.as_ptr();
let mut acc = vdupq_n_f32(0.0);
// Process 8 quants starting at byte offset $qoff against the 8 activations starting at $xoff.
macro_rules! fma_group {
($qoff:expr, $xoff:expr) => {{
let q8 = vld1_s8(q_ptr.add($qoff));
let q16 = vmovl_s8(q8);
let qlo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
let qhi = vcvtq_f32_s32(vmovl_high_s16(q16));
acc = vfmaq_f32(acc, qlo, vld1q_f32(xp.add($xoff)));
acc = vfmaq_f32(acc, qhi, vld1q_f32(xp.add($xoff + 4)));
}};
}
// Four groups of eight cover all QK = 32 quants.
fma_group!(0, 0);
fma_group!(8, 8);
fma_group!(16, 16);
fma_group!(24, 24);
vaddvq_f32(acc) * scale
}
/// Dot product of a row of packed Q8_0 blocks with the f32 activations `x`.
///
/// On x86_64, the AVX2+FMA kernel is used when the running CPU supports both features.
/// Otherwise each [`Q8_0_BLOCK_BYTES`]-byte block `b` is paired with
/// `x[b * QK..(b + 1) * QK]`, and the per-block results are summed in order.
pub fn dot_q8_0_row_f32(row_blocks: &[u8], x: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        // SAFETY: the CPU features `dot_q8_0_row_avx2` is compiled for were detected at runtime.
        return unsafe { dot_q8_0_row_avx2(row_blocks, x) };
    }
    debug_assert_eq!(x.len() % QK, 0);
    row_blocks
        .chunks_exact(Q8_0_BLOCK_BYTES)
        .enumerate()
        .fold(0.0f32, |acc, (b, block)| {
            let start = b * QK;
            acc + dot_q8_0_block_f32(block, &x[start..start + QK])
        })
}
/// AVX2+FMA kernel for `dot_q8_0_row_f32`: dot product of a row of Q8_0 blocks with `x`.
///
/// Each block's `QK` signed quants are processed eight at a time. Each group of eight is
/// loaded with one 64-bit load, sign-extended to eight `i32` lanes, converted to f32, and
/// multiply-accumulated with the matching eight activations. The per-block sum is then
/// scaled by the block's f16 scale and accumulated into the row total.
///
/// # Safety
///
/// The executing CPU must support `avx2` and `fma`. Every memory access stays inside a
/// `chunks_exact` block or a bounds-checked `QK`-element slice of `x`, so there is no
/// length precondition.
///
/// # Panics
///
/// Panics if `x` holds fewer than `QK` values for some block of `row_blocks`, matching
/// the scalar path.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_q8_0_row_avx2(row_blocks: &[u8], x: &[f32]) -> f32 {
    let k = x.len();
    debug_assert_eq!(k % QK, 0);
    let mut row_acc = _mm256_setzero_ps();
    for (b, block) in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
        let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
        // `block` is exactly Q8_0_BLOCK_BYTES long, so `quants` holds exactly QK bytes.
        let quants = &block[2..];
        // Bounds-checked: a too-short `x` panics here instead of being read out of bounds.
        let xs = &x[b * QK..b * QK + QK];
        let mut block_acc = _mm256_setzero_ps();
        for g in 0..QK / 8 {
            // Load exactly 8 quant bytes into the low 64 bits, then sign-extend them to 8 x i32.
            let q8 = _mm_loadl_epi64(quants.as_ptr().add(g * 8).cast::<__m128i>());
            let qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8));
            let xv = _mm256_loadu_ps(xs.as_ptr().add(g * 8));
            block_acc = _mm256_fmadd_ps(qf, xv, block_acc);
        }
        row_acc = _mm256_fmadd_ps(block_acc, _mm256_set1_ps(scale), row_acc);
    }
    // Horizontal sum of the eight lanes of `row_acc`.
    let lo = _mm256_castps256_ps128(row_acc);
    let hi = _mm256_extractf128_ps(row_acc, 1);
    let sum4 = _mm_add_ps(lo, hi);
    let shuf = _mm_movehdup_ps(sum4);
    let sum2 = _mm_add_ps(sum4, shuf);
    let sum1 = _mm_add_ss(sum2, _mm_movehl_ps(shuf, sum2));
    _mm_cvtss_f32(sum1)
}
/// Number of weights in one k-quant super-block.
pub const QK_K: usize = 256;
/// Q4_K block: f16 `d`, f16 `dmin`, 12 packed 6-bit scale/min bytes, then `QK_K / 2` nibble bytes.
pub const Q4_K_BLOCK_BYTES: usize = 2 + 2 + 12 + QK_K / 2;
/// Q5_K block: the Q4_K header, `QK_K / 8` fifth-bit bytes, then `QK_K / 2` nibble bytes.
pub const Q5_K_BLOCK_BYTES: usize = 2 + 2 + 12 + QK_K / 8 + QK_K / 2;
/// Q6_K block: `QK_K / 2` low-nibble bytes, `QK_K / 4` high-2-bit bytes, `QK_K / 16` i8 scales, f16 `d`.
pub const Q6_K_BLOCK_BYTES: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;
/// Unpacks the 6-bit (scale, min) pair for sub-block `j` (0..8) from the 12 packed bytes.
///
/// For `j < 4`, both values are the low 6 bits of `scales[j]` and `scales[j + 4]`. For
/// `j >= 4`, the low 4 bits come from a nibble of `scales[j + 4]` and the top 2 bits come
/// from the spare high bits of `scales[j - 4]` (scale) and `scales[j]` (min).
#[inline(always)]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j >= 4 {
        let packed = scales[j + 4];
        let scale = (packed & 0x0F) | ((scales[j - 4] >> 6) << 4);
        let min = (packed >> 4) | ((scales[j] >> 6) << 4);
        return (scale, min);
    }
    (scales[j] & 0x3F, scales[j + 4] & 0x3F)
}
/// Dot product of a row of Q4_K super-blocks with the f32 activations `x`.
///
/// Block layout (144 bytes): f16 `d` (bytes 0..2), f16 `dmin` (2..4), packed 6-bit
/// scales/mins (4..16), then 128 nibble bytes. Each 64-weight group `g` reads 32 bytes.
/// Their low nibbles are weights `0..32` of the group (sub-block `2g`), and their high
/// nibbles are weights `32..64` (sub-block `2g + 1`). Each weight is `d * sc * q - dmin * m`.
///
/// # Panics
///
/// Panics if `x` is shorter than `QK_K` values per complete block in `row_data`.
pub fn dot_q4_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for (blk, block) in row_data.chunks_exact(Q4_K_BLOCK_BYTES).enumerate() {
        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let scales = &block[4..16];
        let qs = &block[16..];
        for group in 0..QK_K / 64 {
            let sub = 2 * group;
            let (sc_lo, mn_lo) = get_scale_min_k4(sub, scales);
            let (sc_hi, mn_hi) = get_scale_min_k4(sub + 1, scales);
            let (d_lo, m_lo) = (d * sc_lo as f32, dmin * mn_lo as f32);
            let (d_hi, m_hi) = (d * sc_hi as f32, dmin * mn_hi as f32);
            let packed = &qs[32 * group..32 * group + 32];
            let base = blk * QK_K + 64 * group;
            for (l, &q) in packed.iter().enumerate() {
                sum += (d_lo * (q & 0x0F) as f32 - m_lo) * x[base + l];
                sum += (d_hi * (q >> 4) as f32 - m_hi) * x[base + l + 32];
            }
        }
    }
    sum
}
/// Dot product of a row of Q5_K super-blocks with the f32 activations `x`.
///
/// Block layout (176 bytes): f16 `d` (bytes 0..2), f16 `dmin` (2..4), packed 6-bit
/// scales/mins (4..16), the fifth-bit plane `qh` (16..48), then the low nibbles `ql` (48..176).
/// Each 64-weight group `g` uses sub-block scales `2g` and `2g + 1`:
/// - Weight `l` of the first half (low nibble of `ql[32g + l]`) takes its fifth bit
///   from bit `2g` of `qh[l]`.
/// - Weight `l` of the second half (high nibble) takes its fifth bit from bit `2g + 1`
///   of `qh[l]`.
///
/// Each weight is `d * sc * q - dmin * m`.
///
/// # Panics
///
/// Panics if `x` is shorter than `QK_K` values per complete block in `row_data`.
pub fn dot_q5_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    let mut x_off = 0usize;
    for block in row_data.chunks_exact(Q5_K_BLOCK_BYTES) {
        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let scales = &block[4..16];
        let qh = &block[16..48];
        let ql = &block[48..Q5_K_BLOCK_BYTES];
        for group in 0..(QK_K / 64) {
            let is = 2 * group;
            let (sc1, m1) = get_scale_min_k4(is, scales);
            let d1 = d * sc1 as f32;
            let m1v = dmin * m1 as f32;
            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
            let d2 = d * sc2 as f32;
            let m2v = dmin * m2 as f32;
            // Masks selecting this group's two fifth-bit planes inside each `qh` byte.
            let u1 = 1u8 << (2 * group);
            let u2 = 2u8 << (2 * group);
            let ql_group = &ql[32 * group..32 * group + 32];
            // The fifth bit of weight `l` lives in `qh[l]`, not in one shared byte per group.
            for (l, (&lo, &hbits)) in ql_group.iter().zip(qh).enumerate() {
                let hi1 = if hbits & u1 != 0 { 16.0f32 } else { 0.0 };
                let hi2 = if hbits & u2 != 0 { 16.0f32 } else { 0.0 };
                acc += (d1 * ((lo & 0x0F) as f32 + hi1) - m1v) * x[x_off + l];
                acc += (d2 * ((lo >> 4) as f32 + hi2) - m2v) * x[x_off + l + 32];
            }
            x_off += 64;
        }
    }
    acc
}
/// Dot product of a row of Q6_K super-blocks with the f32 activations `x`.
///
/// Block layout (210 bytes): low nibbles `ql` (bytes 0..128), high 2-bit pairs `qh`
/// (128..192), 16 signed 8-bit sub-block scales `sc` (192..208), then f16 `d` (208..210).
/// Each 128-weight half is split into four 32-weight quarters, and every 16 weights have
/// their own scale. Weight `l` of the quarter starting at offset `32 * k` uses
/// `sc[8 * half + l / 16 + 2 * k]`. Each weight is `d * sc * (q - 32)`.
///
/// # Panics
///
/// Panics if `x` is shorter than `QK_K` values per complete block in `row_data`.
pub fn dot_q6_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    let mut x_off = 0usize;
    for block in row_data.chunks_exact(Q6_K_BLOCK_BYTES) {
        let ql = &block[0..128];
        let qh = &block[128..192];
        let sc = &block[192..208];
        let d = f16::from_le_bytes([block[208], block[209]]).to_f32();
        let mut ql_off = 0usize;
        let mut qh_off = 0usize;
        let mut sc_off = 0usize;
        for _ in 0..(QK_K / 128) {
            for l in 0..32 {
                // Lanes 0..16 and 16..32 of each quarter use different scales; the four
                // quarters are two scales apart.
                let is = sc_off + l / 16;
                let lo = ql[ql_off + l];
                let lo2 = ql[ql_off + l + 32];
                let h = qh[qh_off + l];
                let q1 = (((lo & 0x0F) | ((h & 3) << 4)) as i32 - 32) as f32;
                let q2 = (((lo2 & 0x0F) | (((h >> 2) & 3) << 4)) as i32 - 32) as f32;
                let q3 = (((lo >> 4) | (((h >> 4) & 3) << 4)) as i32 - 32) as f32;
                let q4 = (((lo2 >> 4) | (((h >> 6) & 3) << 4)) as i32 - 32) as f32;
                acc += d * sc[is] as i8 as f32 * q1 * x[x_off + l];
                acc += d * sc[is + 2] as i8 as f32 * q2 * x[x_off + l + 32];
                acc += d * sc[is + 4] as i8 as f32 * q3 * x[x_off + l + 64];
                acc += d * sc[is + 6] as i8 as f32 * q4 * x[x_off + l + 96];
            }
            x_off += 128;
            ql_off += 64;
            qh_off += 32;
            // Each 128-weight half consumes 8 of the 16 scales.
            sc_off += 8;
        }
    }
    acc
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Seed for the default stream returned by `seq`.
    const WEIGHT_SEED: u64 = 0x9E37_79B9_7F4A_7C15;
    /// Independent seed for activations, so weights and activations are not proportional.
    const ACT_SEED: u64 = 0xD1B5_4A32_D192_ED03;

    /// Deterministic xorshift64 stream of `n` values in `[-1, 1)` starting from `seed`.
    ///
    /// `seed` must be non-zero, because zero is a fixed point of xorshift and would
    /// produce all `-1.0`.
    fn seq_seeded(n: usize, seed: u64) -> Vec<f32> {
        assert_ne!(seed, 0, "xorshift64 seed must be non-zero");
        let mut s = seed;
        (0..n)
            .map(|_| {
                s ^= s << 13;
                s ^= s >> 7;
                s ^= s << 17;
                ((s >> 40) as f32 / (1u32 << 24) as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Deterministic stream of `n` values in `[-1, 1)` from the fixed weight seed.
    fn seq(n: usize) -> Vec<f32> {
        seq_seeded(n, WEIGHT_SEED)
    }

    #[test]
    fn q4_0_on_the_fly_dot_matches_dequantized_reference() {
        let k = 256;
        let w = seq(k);
        // Use an independent stream: reusing `seq(k)` would make x exactly w / 2.
        let x: Vec<f32> = seq_seeded(k, ACT_SEED).iter().map(|v| v * 0.5).collect();
        let blocks = quantize_q4_0_row(&w);
        assert_eq!(blocks.len(), k / QK * Q4_0_BLOCK_BYTES);
        let mut w_hat = vec![0.0f32; k];
        for (b, chunk) in blocks.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
            dequantize_q4_0_block(chunk, &mut w_hat[b * QK..b * QK + QK]);
        }
        let reference: f32 = w_hat.iter().zip(&x).map(|(a, b)| a * b).sum();
        let on_the_fly = dot_q4_0_row_f32(&blocks, &x);
        assert!(
            (on_the_fly - reference).abs() < 1e-3,
            "on-the-fly {on_the_fly} vs reference {reference}"
        );
    }

    #[test]
    fn q4_0_quantization_error_is_bounded() {
        let w = seq(QK * 4);
        let blocks = quantize_q4_0_row(&w);
        let mut w_hat = vec![0.0f32; w.len()];
        for (b, chunk) in blocks.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
            dequantize_q4_0_block(chunk, &mut w_hat[b * QK..b * QK + QK]);
        }
        let max_err = w
            .iter()
            .zip(&w_hat)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 0.2, "max quant error {max_err} too large");
    }

    /// The platform dispatch (NEON on aarch64) must agree with the portable kernel.
    #[test]
    fn q4_0_block_dispatch_matches_scalar_kernel() {
        let w = seq(QK);
        let x = seq_seeded(QK, ACT_SEED);
        let block = quantize_q4_0_block(&w);
        let dispatched = dot_q4_0_block_f32(&block, &x);
        let scalar = dot_q4_0_block_scalar(&block, &x);
        assert!(
            (dispatched - scalar).abs() <= 1e-3 * scalar.abs().max(1.0),
            "dispatched {dispatched} vs scalar {scalar}"
        );
    }

    /// Exercises the AVX2 row kernel on capable x86_64 CPUs (and the per-block path
    /// elsewhere) against an independent scalar reference, including the last block.
    #[test]
    fn q8_0_row_matches_scalar_reference() {
        let n_blocks = 4;
        let x = seq_seeded(n_blocks * QK, ACT_SEED);
        let mut row = Vec::with_capacity(n_blocks * Q8_0_BLOCK_BYTES);
        let mut reference = 0.0f32;
        for b in 0..n_blocks {
            let scale = half::f16::from_f32(0.01 * (b + 1) as f32);
            row.extend_from_slice(&scale.to_le_bytes());
            let mut block_sum = 0.0f32;
            for j in 0..QK {
                // Cycles through the whole i8 range, including negative values.
                let q = ((j * 37 + b * 11) % 256) as u8 as i8;
                row.push(q as u8);
                block_sum += q as f32 * x[b * QK + j];
            }
            reference += block_sum * scale.to_f32();
        }
        let got = dot_q8_0_row_f32(&row, &x);
        assert!(
            (got - reference).abs() <= 1e-3 * reference.abs().max(1.0),
            "row dot {got} vs reference {reference}"
        );
    }
}