#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use half::f16;
use super::super::dequant_k_quants::unpack_q4k_q5k_scales;
/// Fused dequantize-and-dot of a Q4K-encoded weight row against `act`, using AVX2 + FMA.
///
/// The row is read as `k / 256` consecutive blocks of 144 bytes from `blocks`. Within each
/// block, bytes `0..2` and `2..4` are little-endian f16 values `d` and `dmin`, bytes `4..16`
/// are handed to `unpack_q4k_q5k_scales` to obtain eight per-sub-block scales and mins, and
/// bytes `16..144` hold the 4-bit quants (two per byte). Every weight contributes
/// `act[i] * (d * scale * q - dmin * min)` to the result.
///
/// Any trailing `k % 256` elements are ignored.
///
/// # Panics
///
/// Panics if `blocks` holds fewer than `(k / 256) * 144` bytes or `act` holds fewer than
/// `(k / 256) * 256` elements.
///
/// # Safety
///
/// The executing CPU must support the `avx2` and `fma` target features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn fused_dot_q4k_avx2(act: &[f32], blocks: &[u8], k: usize) -> f32 {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 144;
    const SUB_BLOCK: usize = 32;
    const LANES: usize = 8;

    let num_blocks = k / BLOCK_SIZE;
    let mut total_acc = _mm256_setzero_ps();
    let nibble_mask = _mm256_set1_epi32(0x0F);

    for b in 0..num_blocks {
        // Exact-length slices: a short `blocks` or `act` panics here instead of letting the
        // raw-pointer loads below run past the end of the caller's buffers.
        let block = &blocks[b * BLOCK_BYTES..(b + 1) * BLOCK_BYTES];
        let act_block = &act[b * BLOCK_SIZE..(b + 1) * BLOCK_SIZE];

        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let sc = &block[4..16];
        let qs = &block[16..144];
        let (scales, mins) = unpack_q4k_q5k_scales(sc);

        for j in 0..8 {
            let dl = d * scales[j] as f32;
            let ml = dmin * mins[j] as f32;
            // Two consecutive sub-blocks share one 32-byte run of `qs`: the even one uses
            // the low nibbles, the odd one the high nibbles.
            let chunk = j / 2;
            let is_high = j % 2 == 1;
            let qs_base = chunk * SUB_BLOCK;
            let act_sub = &act_block[j * SUB_BLOCK..(j + 1) * SUB_BLOCK];
            let dl_vec = _mm256_set1_ps(dl);
            let ml_vec = _mm256_set1_ps(ml);

            for g in 0..4 {
                let l_base = g * LANES;
                let q_lane = &qs[qs_base + l_base..qs_base + l_base + LANES];
                let a_lane = &act_sub[l_base..l_base + LANES];
                // SAFETY: `q_lane` is exactly 8 bytes and `a_lane` exactly 8 `f32`s (both are
                // bounds-checked slices), so the unaligned 64-bit and 256-bit loads stay in
                // bounds. AVX2/FMA availability is guaranteed by this function's contract.
                unsafe {
                    let raw = _mm_loadl_epi64(q_lane.as_ptr() as *const __m128i);
                    // Zero-extend the 8 packed bytes to 8 x i32 lanes.
                    let raw256 = _mm256_cvtepu8_epi32(raw);
                    let nibbles = if is_high {
                        _mm256_srli_epi32(raw256, 4)
                    } else {
                        raw256
                    };
                    let nibbles = _mm256_and_si256(nibbles, nibble_mask);
                    let q_f32 = _mm256_cvtepi32_ps(nibbles);
                    let a = _mm256_loadu_ps(a_lane.as_ptr());
                    let aq = _mm256_mul_ps(a, q_f32);
                    // acc += dl * (a * q); acc -= ml * a
                    total_acc = _mm256_fmadd_ps(dl_vec, aq, total_acc);
                    total_acc = _mm256_fnmadd_ps(ml_vec, a, total_acc);
                }
            }
        }
    }

    // SAFETY: AVX2 is available per this function's contract.
    // NOTE(review): assumes `hsum_f32_avx2` needs no CPU feature beyond AVX2 — confirm.
    unsafe { super::dot_f32::hsum_f32_avx2(total_acc) }
}
/// Dot product of `act` with the Q4K-quantized row `blocks` of logical length `k`, dispatched
/// to an implementation chosen for the compilation target.
///
/// * `x86_64`: uses `fused_dot_q4k_avx2` when both `avx2` and `fma` are detected at runtime,
///   otherwise the generic `fused_dot_row` path with `QuantFormat::Q4K`.
/// * `aarch64`: always uses the NEON kernel `fused_dot_q4k_neon`.
/// * any other architecture: the generic `fused_dot_row` path with `QuantFormat::Q4K`.
pub fn fused_dot_q4k(act: &[f32], blocks: &[u8], k: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        let simd_available = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
        if simd_available {
            // SAFETY: AVX2 and FMA were both detected at runtime immediately above, which is
            // the CPU-feature requirement declared on `fused_dot_q4k_avx2`.
            unsafe { fused_dot_q4k_avx2(act, blocks, k) }
        } else {
            super::super::fused_dot::fused_dot_row(act, blocks, k, crate::quant::QuantFormat::Q4K)
        }
    }
    #[cfg(target_arch = "aarch64")]
    // SAFETY: NOTE(review): no runtime feature check is performed here; this presumably relies
    // on NEON being baseline for the aarch64 targets this crate supports — confirm against the
    // kernel's own safety contract.
    unsafe {
        super::aarch64::fused_q4k::fused_dot_q4k_neon(act, blocks, k)
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    super::super::fused_dot::fused_dot_row(act, blocks, k, crate::quant::QuantFormat::Q4K)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quant::cpu::kernels::dequant;

    /// Size in bytes of one encoded block as laid out by these fixtures.
    const BLOCK_BYTES: usize = 144;

    /// Writes the two little-endian f16 header fields into the first four bytes of `block`.
    fn write_header(block: &mut [u8], d: f32, dmin: f32) {
        block[0..2].copy_from_slice(&f16::from_f32(d).to_le_bytes());
        block[2..4].copy_from_slice(&f16::from_f32(dmin).to_le_bytes());
    }

    /// Reference result: dequantize `weight` with `dequant_q4k`, then take a plain scalar dot
    /// product with `act`.
    fn reference_dot(act: &[f32], weight: &[u8]) -> f32 {
        let mut dequantized = vec![0.0f32; act.len()];
        dequant::dequant_q4k(weight, &mut dequantized);
        act.iter().zip(dequantized.iter()).map(|(a, b)| a * b).sum()
    }

    /// Asserts `fused` matches `reference` within a 1e-4 relative plus `abs_tol` absolute bound.
    #[track_caller]
    fn assert_close(fused: f32, reference: f32, abs_tol: f32) {
        assert!(
            (fused - reference).abs() < reference.abs() * 1e-4 + abs_tol,
            "fused={fused}, reference={reference}, diff={}",
            (fused - reference).abs()
        );
    }

    /// One block with constant scale bytes and constant quant bytes.
    #[test]
    fn test_fused_q4k_avx2_vs_dequant() {
        let k = 256;
        let act: Vec<f32> = (0..k).map(|i| (i as f32) * 0.01).collect();

        let mut block = [0u8; BLOCK_BYTES];
        write_header(&mut block, 1.0, 0.5);
        block[4..8].fill(0x03);
        block[8..12].fill(0x02);
        block[12..16].fill(0x10);
        block[16..144].fill(0x73);

        let fused = fused_dot_q4k(&act, &block, k);
        let reference = reference_dot(&act, &block);
        assert_close(fused, reference, 1e-2);
    }

    /// Two identical blocks against a sinusoidal activation vector.
    #[test]
    fn test_fused_q4k_avx2_multi_block() {
        let k = 512;
        let act: Vec<f32> = (0..k).map(|i| ((i as f32) * 0.01).sin()).collect();

        let mut weight = vec![0u8; BLOCK_BYTES * 2];
        for block in weight.chunks_exact_mut(BLOCK_BYTES) {
            write_header(block, 1.5, 0.3);
            block[4..8].fill(0x05);
            block[8..12].fill(0x01);
            block[16..144].fill(0xA5);
        }

        let fused = fused_dot_q4k(&act, &weight, k);
        let reference = reference_dot(&act, &weight);
        assert_close(fused, reference, 1e-2);
    }

    /// Sixteen blocks whose header, scale bytes and quant bytes vary with the block index.
    #[test]
    fn test_fused_q4k_avx2_large() {
        let k = 4096;
        let act: Vec<f32> = (0..k).map(|i| ((i as f32) * 0.003).sin()).collect();

        let mut weight = vec![0u8; BLOCK_BYTES * 16];
        for (blk, block) in weight.chunks_exact_mut(BLOCK_BYTES).enumerate() {
            write_header(block, 0.8 + blk as f32 * 0.05, 0.1 + blk as f32 * 0.01);
            block[4..8].fill((blk as u8 % 10) + 1);
            block[8..12].fill((blk as u8 % 5) + 1);
            for (i, byte) in block.iter_mut().enumerate().skip(16) {
                *byte = ((blk * 17 + i * 31) % 256) as u8;
            }
        }

        let fused = fused_dot_q4k(&act, &weight, k);
        let reference = reference_dot(&act, &weight);
        assert_close(fused, reference, 1e-1);
    }
}