#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use half::f16;
use super::super::dequant_k_quants::unpack_q4k_q5k_scales;
use super::quantize_act_q8k::Q8K_BLOCK_BYTES;
/// AVX2 kernel: dot product of Q8_K-quantized activations with Q5_K-quantized weights.
///
/// Processes `k / 256` blocks (any remainder of `k` is ignored) and returns the
/// accumulated f32 result.
///
/// Byte layout as read by this function:
/// - weight block, 176-byte stride: f16 `d` at `0..2`, f16 `dmin` at `2..4`,
///   12 packed scale/min bytes at `4..16`, packed high bits at `16..48`,
///   packed 4-bit values at `48..176`.
/// - activation block, `Q8K_BLOCK_BYTES` stride: little-endian f32 scale at
///   `0..4`, 256 signed 8-bit values at `4..260`, sixteen little-endian i16
///   partial sums ("bsums") at `260..292`, consumed in pairs per 32-element
///   sub-block.
///
/// NOTE(review): this path pairs the low nibbles of 32 consecutive `qs` bytes
/// with the first 32 elements of each 64-element group and the high nibbles
/// with the next 32, whereas `fused_dot_q5k_q8k_scalar` interleaves them
/// (element `2i` = low nibble of byte `i`, element `2i + 1` = high nibble).
/// The two paths cannot both match the stored layout — verify against
/// `dequant_q5k`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
///
/// # Panics
///
/// Panics if `weight` or `act_q8k` is too short to hold `k / 256` blocks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn fused_dot_q5k_q8k_avx2(act_q8k: &[u8], weight: &[u8], k: usize) -> f32 {
    // SAFETY: every raw 32-byte load below stays inside a slice whose length
    // was bounds-checked when it was created (`qs` is 128 bytes and is read at
    // offsets 0/32/64/96; `q8_qs` is 256 bytes and is read at offsets up to
    // 224). `_mm256_loadu_si256` has no alignment requirement, and the AVX2
    // requirement of `load_high_bits_32` is this function's own contract.
    unsafe {
        const Q5K_BLOCK_BYTES: usize = 176;
        const Q5K_BLOCK_SIZE: usize = 256;
        let num_blocks = k / Q5K_BLOCK_SIZE;
        let nibble_mask = _mm256_set1_epi8(0x0F);
        let mut sumf = 0.0f32;
        for b in 0..num_blocks {
            let q5k = &weight[b * Q5K_BLOCK_BYTES..];
            let d = f16::from_le_bytes([q5k[0], q5k[1]]).to_f32();
            let dmin = f16::from_le_bytes([q5k[2], q5k[3]]).to_f32();
            let sc = &q5k[4..16];
            let qh = &q5k[16..48];
            let qs = &q5k[48..176];
            let (scales, mins) = unpack_q4k_q5k_scales(sc);

            let q8k_block = &act_q8k[b * Q8K_BLOCK_BYTES..];
            let d8 = f32::from_le_bytes(q8k_block[0..4].try_into().expect("exact-size slice"));
            let q8_qs = &q8k_block[4..260];
            // Bounds-checked view of the partial sums. They are decoded from
            // bytes because a `&[u8]` gives no 2-byte alignment guarantee, so
            // dereferencing it as `*const i16` would be undefined behaviour.
            let q8_bsums = &q8k_block[260..292];
            let bsum_at = |i: usize| {
                i32::from(i16::from_le_bytes([q8_bsums[2 * i], q8_bsums[2 * i + 1]]))
            };

            let dall = d * d8;
            let dmin_all = dmin * d8;
            // Eight i32 partial sums, reduced once per block.
            let mut acc = _mm256_setzero_si256();
            let mut sumi_mins = 0i32;
            for j in 0..4 {
                let q4_raw = _mm256_loadu_si256(qs.as_ptr().add(j * 32) as *const __m256i);
                let q4_lo_nib = _mm256_and_si256(q4_raw, nibble_mask);
                let q4_hi_nib = _mm256_and_si256(_mm256_srli_epi16(q4_raw, 4), nibble_mask);
                let qh_lo = load_high_bits_32(qh, j * 64);
                let qh_hi = load_high_bits_32(qh, j * 64 + 32);
                // Each high-bit byte is 0 or 1, so the 16-bit-lane shift by 4
                // lands on bit 4 of the same byte and never crosses into its
                // neighbour; the OR therefore yields a clean 5-bit value.
                let q5_lo = _mm256_or_si256(q4_lo_nib, _mm256_slli_epi16(qh_lo, 4));
                let q5_hi = _mm256_or_si256(q4_hi_nib, _mm256_slli_epi16(qh_hi, 4));

                let q8_lo = _mm256_loadu_si256(q8_qs.as_ptr().add(j * 64) as *const __m256i);
                let q8_hi = _mm256_loadu_si256(q8_qs.as_ptr().add(j * 64 + 32) as *const __m256i);
                // unsigned(q5) * signed(q8), adjacent pairs summed into i16.
                let dot_lo = _mm256_maddubs_epi16(q5_lo, q8_lo);
                let dot_hi = _mm256_maddubs_epi16(q5_hi, q8_hi);
                let scale_lo = _mm256_set1_epi16(scales[j * 2] as i16);
                let scale_hi = _mm256_set1_epi16(scales[j * 2 + 1] as i16);
                let p_lo = _mm256_madd_epi16(dot_lo, scale_lo);
                let p_hi = _mm256_madd_epi16(dot_hi, scale_hi);
                acc = _mm256_add_epi32(acc, _mm256_add_epi32(p_lo, p_hi));

                let bsum_lo = bsum_at(j * 4) + bsum_at(j * 4 + 1);
                let bsum_hi = bsum_at(j * 4 + 2) + bsum_at(j * 4 + 3);
                sumi_mins += (mins[j * 2] as i32) * bsum_lo + (mins[j * 2 + 1] as i32) * bsum_hi;
            }

            // Horizontal sum of the eight i32 lanes of `acc`.
            let hi128 = _mm256_extracti128_si256(acc, 1);
            let lo128 = _mm256_castsi256_si128(acc);
            let sum128 = _mm_add_epi32(lo128, hi128);
            let sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, 0x4E));
            let sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, 0xB1));
            let sumi = _mm_cvtsi128_si32(sum128);

            sumf += dall * sumi as f32 - dmin_all * sumi_mins as f32;
        }
        sumf
    }
}
/// Expands 32 consecutive packed bits of `qh` into a vector of 32 bytes, each 0 or 1.
///
/// Byte lane `i` of the result holds bit number `elem_off + i` of `qh`, where
/// bit `n` is bit `n % 8` (least-significant first) of byte `n / 8`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
///
/// # Panics
///
/// Panics if `qh` holds fewer than `elem_off + 32` bits.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn load_high_bits_32(qh: &[u8], elem_off: usize) -> __m256i {
    let lanes: [u8; 32] = std::array::from_fn(|lane| {
        let pos = elem_off + lane;
        (qh[pos >> 3] >> (pos & 7)) & 1
    });
    // SAFETY: `lanes` is a live 32-byte array and the load tolerates any alignment.
    unsafe { _mm256_loadu_si256(lanes.as_ptr().cast::<__m256i>()) }
}
/// Dot product of Q8_K-quantized activations with Q5_K-quantized weights.
///
/// On x86_64 the AVX2 kernel is used when the CPU reports AVX2 at runtime;
/// otherwise (and on every other architecture) the scalar implementation runs.
/// `k` is the element count; both kernels process `k / 256` blocks.
///
/// NOTE(review): the two kernels map the packed 4-bit values to elements in
/// different orders, so the result can depend on which path is taken —
/// confirm which one matches the stored weight layout.
pub fn fused_dot_q5k_q8k(act_q8k: &[u8], weight: &[u8], k: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        let avx2_available = is_x86_feature_detected!("avx2");
        if avx2_available {
            // SAFETY: the kernel's target-feature requirement (AVX2) was
            // confirmed at runtime on the line above.
            return unsafe { fused_dot_q5k_q8k_avx2(act_q8k, weight, k) };
        }
    }
    fused_dot_q5k_q8k_scalar(act_q8k, weight, k)
}
/// Portable fallback for the fused Q5_K × Q8_K dot product.
///
/// Walks `k / 256` blocks. For each block it accumulates, in integer
/// arithmetic, the scale-weighted products of the 5-bit weights with the
/// signed 8-bit activations and the min-weighted activation partial sums,
/// then combines both with the block's floating-point scales.
///
/// NOTE(review): here element `2i` of a sub-block takes the low nibble of
/// `qs` byte `i` and element `2i + 1` the high nibble. The AVX2 kernel instead
/// pairs the low nibbles of 32 consecutive bytes with the first 32 elements of
/// a 64-element group and the high nibbles with the next 32. The two cannot
/// both match the stored layout — verify against `dequant_q5k`.
///
/// # Panics
///
/// Panics if `weight` or `act_q8k` is too short to hold `k / 256` blocks.
fn fused_dot_q5k_q8k_scalar(act_q8k: &[u8], weight: &[u8], k: usize) -> f32 {
    const Q5K_BLOCK_BYTES: usize = 176;
    const Q5K_BLOCK_SIZE: usize = 256;
    const SUB_BLOCKS: usize = 8;
    const SUB_BLOCK_LEN: usize = 32;

    // Decodes one little-endian i16 partial sum and widens it to i32.
    let read_i16 = |bytes: &[u8]| -> i32 {
        i32::from(i16::from_le_bytes(
            bytes.try_into().expect("exact-size slice"),
        ))
    };

    let mut total = 0.0f32;
    for block in 0..k / Q5K_BLOCK_SIZE {
        // Weight block fields.
        let w = &weight[block * Q5K_BLOCK_BYTES..];
        let d = f16::from_le_bytes([w[0], w[1]]).to_f32();
        let dmin = f16::from_le_bytes([w[2], w[3]]).to_f32();
        let packed_scales = &w[4..16];
        let high_bits = &w[16..48];
        let nibbles = &w[48..176];
        let (scales, mins) = unpack_q4k_q5k_scales(packed_scales);

        // Activation block fields.
        let a = &act_q8k[block * Q8K_BLOCK_BYTES..];
        let d8 = f32::from_le_bytes(a[0..4].try_into().expect("exact-size slice"));
        let acts = &a[4..260];
        let bsums = &a[260..292];

        let dall = d * d8;
        let dmin_all = dmin * d8;

        let mut int_dot = 0i32;
        let mut int_mins = 0i32;
        for sub in 0..SUB_BLOCKS {
            let first = sub * SUB_BLOCK_LEN;
            let sub_dot: i32 = (0..SUB_BLOCK_LEN)
                .map(|lane| {
                    let elem = first + lane;
                    let packed = nibbles[sub * 16 + lane / 2];
                    // Even lanes use the low nibble, odd lanes the high one.
                    let low4 = i32::from(if lane % 2 == 0 {
                        packed & 0x0F
                    } else {
                        packed >> 4
                    });
                    let high1 = i32::from((high_bits[elem / 8] >> (elem % 8)) & 1);
                    (low4 | (high1 << 4)) * i32::from(acts[elem] as i8)
                })
                .sum();
            int_dot += sub_dot * scales[sub] as i32;

            // Two adjacent i16 partial sums are combined per sub-block.
            let pair = &bsums[sub * 4..sub * 4 + 4];
            let activation_sum = read_i16(&pair[0..2]) + read_i16(&pair[2..4]);
            int_mins += mins[sub] as i32 * activation_sum;
        }

        total += dall * int_dot as f32 - dmin_all * int_mins as f32;
    }
    total
}
#[cfg(test)]
mod tests {
    use super::super::quantize_act_q8k::{Q8K_BLOCK_BYTES, quantize_f32_to_q8k};
    use super::*;
    use crate::quant::cpu::kernels::dequant;

    /// f32 reference: dequantizes `weight` with `dequant_q5k` and takes the
    /// plain dot product with the unquantized activations.
    fn dequant_reference(act: &[f32], weight: &[u8]) -> f32 {
        let mut dequant_buf = vec![0.0f32; act.len()];
        dequant::dequant_q5k(weight, &mut dequant_buf);
        act.iter()
            .zip(dequant_buf.iter())
            .map(|(a, b)| a * b)
            .sum()
    }

    /// Single block with constant fill patterns, compared against the f32 reference.
    #[test]
    fn test_fused_q5k_q8k_vs_f32_dot() {
        let k = 256;
        let act: Vec<f32> = (0..k).map(|n| (n as f32 - 128.0) * 0.01).collect();

        let mut weight = [0u8; 176];
        // d = 1.0 and dmin = 0.5 as f16.
        weight[0..2].copy_from_slice(&f16::from_f32(1.0).to_le_bytes());
        weight[2..4].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        // Packed scale/min bytes.
        weight[4..8].fill(0x03);
        weight[8..12].fill(0x02);
        weight[12..16].fill(0x10);
        // High-bit bytes, then packed 4-bit values.
        weight[16..48].fill(0xAA);
        weight[48..176].fill(0x73);

        let mut act_q8k = vec![0u8; Q8K_BLOCK_BYTES];
        quantize_f32_to_q8k(&act, &mut act_q8k);

        let result = fused_dot_q5k_q8k(&act_q8k, &weight, k);
        let reference = dequant_reference(&act, &weight);
        assert!(
            (result - reference).abs() < reference.abs() * 0.05 + 2.0,
            "q8k={result}, f32_ref={reference}, diff={}",
            (result - reference).abs()
        );
    }

    /// Sixteen blocks with per-block varying contents, compared against the f32 reference.
    #[test]
    fn test_fused_q5k_q8k_large() {
        let k = 4096;
        let act: Vec<f32> = (0..k).map(|n| ((n as f32) * 0.003).sin()).collect();

        let mut weight = vec![0u8; 176 * 16];
        for (blk, block) in weight.chunks_exact_mut(176).enumerate() {
            block[0..2].copy_from_slice(&f16::from_f32(0.8 + blk as f32 * 0.05).to_le_bytes());
            block[2..4].copy_from_slice(&f16::from_f32(0.1 + blk as f32 * 0.01).to_le_bytes());
            block[4..8].fill((blk as u8 % 10) + 1);
            block[8..12].fill((blk as u8 % 5) + 1);
            // Bytes 12..16 are left at zero.
            block[16..48].fill(((blk * 7 + 3) % 256) as u8);
            for (i, byte) in block.iter_mut().enumerate().skip(48) {
                *byte = ((blk * 17 + i * 31) % 256) as u8;
            }
        }

        let mut act_q8k = vec![0u8; Q8K_BLOCK_BYTES * 16];
        quantize_f32_to_q8k(&act, &mut act_q8k);

        let result = fused_dot_q5k_q8k(&act_q8k, &weight, k);
        let reference = dequant_reference(&act, &weight);
        assert!(
            (result - reference).abs() < reference.abs() * 0.02 + 1.0,
            "q8k={result}, f32_ref={reference}, diff={}",
            (result - reference).abs()
        );
    }
}