#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use half::f16;
use super::super::dequant_k_quants::unpack_q4k_q5k_scales;
/// Fused Q5_K dequantize-and-dot kernel using AVX2 + FMA.
///
/// Computes `sum_i act[i] * w[i]` over the first `(k / 256) * 256` weights, where the
/// weights are stored as consecutive 176-byte Q5_K blocks in `blocks`. Any trailing
/// `k % 256` elements are ignored.
///
/// Block layout as decoded here (256 weights = 8 sub-blocks of 32):
/// - bytes `0..2`: `d`, bytes `2..4`: `dmin` (both little-endian f16);
/// - bytes `4..16`: packed per-sub-block scales and mins, decoded by
///   `unpack_q4k_q5k_scales`;
/// - bytes `16..48` (`qh`): bit `i % 8` of `qh[i / 8]` is bit 4 of weight `i`;
/// - bytes `48..176` (`qs`): the low 4 bits of weight `i` live in `qs[i / 2]`, in the
///   low nibble for even `i` and the high nibble for odd `i`.
///
/// Weight `i` of sub-block `j` contributes `act[i] * (d * scale[j] * q - dmin * min[j])`.
///
/// NOTE(review): this sequential nibble/bit packing appears to differ from ggml's
/// reference Q5_K layout (which splits low/high nibbles across the two 32-weight halves
/// of each 64-weight chunk and indexes `qh` per sub-block). It must agree with
/// `dequant_q5k` and the NEON kernel — confirm against whatever produces these blocks.
///
/// # Panics
///
/// Panics if `act` holds fewer than `(k / 256) * 256` values or `blocks` holds fewer
/// than `(k / 256) * 176` bytes.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` and `fma` target features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn fused_dot_q5k_avx2(act: &[f32], blocks: &[u8], k: usize) -> f32 {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 176;
    const SUB_BLOCK: usize = 32;
    const LANES: usize = 8;

    let num_blocks = k / BLOCK_SIZE;
    // Checked once up front: the SIMD loads below use raw pointers, so a short `act`
    // would otherwise be read out of bounds instead of panicking.
    assert!(
        act.len() >= num_blocks * BLOCK_SIZE,
        "fused_dot_q5k_avx2: act has {} values, need {}",
        act.len(),
        num_blocks * BLOCK_SIZE
    );
    assert!(
        blocks.len() >= num_blocks * BLOCK_BYTES,
        "fused_dot_q5k_avx2: blocks has {} bytes, need {}",
        blocks.len(),
        num_blocks * BLOCK_BYTES
    );

    let mut total_acc = _mm256_setzero_ps();
    for (block, act_block) in blocks
        .chunks_exact(BLOCK_BYTES)
        .zip(act.chunks_exact(BLOCK_SIZE))
        .take(num_blocks)
    {
        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let qh = &block[16..48];
        let qs = &block[48..176];
        let (scales, mins) = unpack_q4k_q5k_scales(&block[4..16]);

        for j in 0..8 {
            let dl_vec = _mm256_set1_ps(d * scales[j] as f32);
            let ml_vec = _mm256_set1_ps(dmin * mins[j] as f32);

            for g in 0..SUB_BLOCK / LANES {
                // Index (within the 256-weight block) of the first of these 8 lanes.
                let base = j * SUB_BLOCK + g * LANES;

                // Reassemble the 5-bit quants: low nibble from `qs`, bit 4 from `qh`.
                let mut vals = [0i32; LANES];
                for (i, val) in vals.iter_mut().enumerate() {
                    let idx = base + i;
                    let low4 = (qs[idx / 2] >> ((idx & 1) * 4)) & 0x0F;
                    let high1 = (qh[idx / 8] >> (idx % 8)) & 1;
                    *val = i32::from(low4 | (high1 << 4));
                }

                // Exact-length, bounds-checked slice backing the activation load.
                let a_chunk = &act_block[base..base + LANES];

                // SAFETY: `vals` is exactly 8 `i32`s (32 bytes) and `a_chunk` is a
                // bounds-checked slice of exactly 8 `f32`s, so both unaligned 256-bit
                // loads stay in bounds; AVX2/FMA support is this function's contract.
                unsafe {
                    let q_i32 = _mm256_loadu_si256(vals.as_ptr() as *const __m256i);
                    let q_f32 = _mm256_cvtepi32_ps(q_i32);
                    let a = _mm256_loadu_ps(a_chunk.as_ptr());
                    let aq = _mm256_mul_ps(a, q_f32);
                    // acc += dl * (a * q) - ml * a
                    total_acc = _mm256_fmadd_ps(dl_vec, aq, total_acc);
                    total_acc = _mm256_fnmadd_ps(ml_vec, a, total_acc);
                }
            }
        }
    }
    // SAFETY: AVX2 support is guaranteed by this function's caller contract.
    unsafe { super::dot_f32::hsum_f32_avx2(total_acc) }
}
/// Dot product of `act` with one row of Q5_K-quantized weights, choosing a kernel
/// for the current target:
///
/// - x86_64: the AVX2+FMA kernel when both features are detected at runtime,
///   otherwise the generic `fused_dot_row` path;
/// - aarch64: the NEON kernel;
/// - any other architecture: the generic `fused_dot_row` path.
pub fn fused_dot_q5k(act: &[f32], blocks: &[u8], k: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        let has_avx2_fma = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
        if has_avx2_fma {
            // SAFETY: both target features the kernel requires were just detected.
            unsafe { fused_dot_q5k_avx2(act, blocks, k) }
        } else {
            super::super::fused_dot::fused_dot_row(act, blocks, k, crate::quant::QuantFormat::Q5K)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // NOTE(review): no runtime feature check here — relies on NEON being enabled
        // for the aarch64 target; see the NEON kernel for its full contract.
        unsafe { super::aarch64::fused_q5k::fused_dot_q5k_neon(act, blocks, k) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        super::super::fused_dot::fused_dot_row(act, blocks, k, crate::quant::QuantFormat::Q5K)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quant::cpu::kernels::dequant;

    /// Bytes per Q5_K block and weights per block, mirroring the kernel constants.
    const BLOCK_BYTES: usize = 176;
    const BLOCK_SIZE: usize = 256;

    /// Reference result: dequantize `weight` with `dequant_q5k`, then take a plain f32 dot.
    fn dot_via_dequant(act: &[f32], weight: &[u8], k: usize) -> f32 {
        let mut dequant_buf = vec![0.0f32; k];
        dequant::dequant_q5k(weight, &mut dequant_buf);
        act.iter().zip(dequant_buf.iter()).map(|(a, b)| a * b).sum()
    }

    /// Asserts `fused` is within 1e-4 relative plus 1e-2 absolute of `reference`.
    #[track_caller]
    fn assert_close(fused: f32, reference: f32) {
        let diff = (fused - reference).abs();
        assert!(
            diff < reference.abs() * 1e-4 + 1e-2,
            "fused={fused}, reference={reference}, diff={diff}"
        );
    }

    /// Builds `num_blocks` Q5_K blocks whose scale, `qh` and `qs` bytes come from a
    /// deterministic xorshift32 stream, so every nibble and high-bit position varies.
    /// `seed` must be non-zero.
    fn pseudo_random_q5k_blocks(num_blocks: usize, seed: u32) -> Vec<u8> {
        let mut state = seed;
        let mut bytes: Vec<u8> = (0..num_blocks * BLOCK_BYTES)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 17;
                state ^= state << 5;
                (state >> 24) as u8
            })
            .collect();
        for block in bytes.chunks_exact_mut(BLOCK_BYTES) {
            // Random bits could encode NaN/Inf as f16, so pin the super-block scales.
            block[0..2].copy_from_slice(&f16::from_f32(0.01).to_le_bytes());
            block[2..4].copy_from_slice(&f16::from_f32(0.02).to_le_bytes());
        }
        bytes
    }

    #[test]
    fn test_fused_q5k_avx2_vs_dequant() {
        let k = 256;
        let act: Vec<f32> = (0..k).map(|i| (i as f32) * 0.01).collect();
        let mut block = [0u8; BLOCK_BYTES];
        block[0..2].copy_from_slice(&f16::from_f32(1.0).to_le_bytes());
        block[2..4].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        block[4..8].fill(0x03);
        block[8..12].fill(0x02);
        block[12..16].fill(0x10);
        block[16..48].fill(0xAA);
        block[48..176].fill(0x73);
        assert_close(fused_dot_q5k(&act, &block, k), dot_via_dequant(&act, &block, k));
    }

    #[test]
    fn test_fused_q5k_avx2_multi_block() {
        let k = 512;
        let act: Vec<f32> = (0..k).map(|i| ((i as f32) * 0.01).sin()).collect();
        let mut weight = vec![0u8; BLOCK_BYTES * 2];
        for blk in 0..2 {
            let base = blk * BLOCK_BYTES;
            weight[base..base + 2].copy_from_slice(&f16::from_f32(1.5).to_le_bytes());
            weight[base + 2..base + 4].copy_from_slice(&f16::from_f32(0.3).to_le_bytes());
            weight[base + 4..base + 8].fill(0x05);
            weight[base + 8..base + 12].fill(0x01);
            weight[base + 16..base + 48].fill(0x55);
            weight[base + 48..base + 176].fill(0xA5);
        }
        assert_close(fused_dot_q5k(&act, &weight, k), dot_via_dequant(&act, &weight, k));
    }

    /// Non-uniform data through the public dispatcher (whichever kernel this target uses).
    #[test]
    fn test_fused_q5k_dispatch_random_blocks() {
        let k = 3 * BLOCK_SIZE;
        let weight = pseudo_random_q5k_blocks(k / BLOCK_SIZE, 0x1234_5678);
        let act: Vec<f32> = (0..k).map(|i| ((i as f32) * 0.037).cos()).collect();
        assert_close(fused_dot_q5k(&act, &weight, k), dot_via_dequant(&act, &weight, k));
    }

    /// Calls the AVX2 kernel directly so it is tested even if dispatch changes.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_fused_q5k_avx2_kernel_random_blocks() {
        if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
            // The kernel cannot run on this CPU; the dispatch tests cover the fallback.
            return;
        }
        let k = 2 * BLOCK_SIZE;
        let weight = pseudo_random_q5k_blocks(k / BLOCK_SIZE, 0x9E37_79B9);
        let act: Vec<f32> = (0..k).map(|i| ((i as f32) * 0.021).sin()).collect();
        // SAFETY: AVX2 and FMA support was verified just above.
        let fused = unsafe { fused_dot_q5k_avx2(&act, &weight, k) };
        assert_close(fused, dot_via_dequant(&act, &weight, k));
    }
}