#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use half::f16;
use super::super::super::dequant_k_quants::unpack_q4k_q5k_scales;
use super::dot_f32::hsum_f32_neon;

/// Number of f32 lanes in a 128-bit NEON vector.
const F32_LANES: usize = 4;
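
/// Fused dot product of `act[..k]` against Q4_K-quantized weight blocks,
/// dequantizing on the fly instead of materializing f32 weights. Each
/// 144-byte block encodes 256 weights as `w = d * scale * q - dmin * min`.
///
/// # Safety
/// Requires NEON (always present on aarch64). `act` must hold at least `k`
/// elements and `blocks` at least `(k / 256) * 144` bytes; the raw-pointer
/// activation loads below stay in bounds only under these conditions
/// (debug-asserted, not checked in release builds).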
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn fused_dot_q4k_neon(act: &[f32], blocks: &[u8], k: usize) -> f32 {
    // Q4_K block layout (144 bytes for 256 weights):
    // 2 bytes d (f16) + 2 bytes dmin (f16) + 12 bytes of packed 6-bit
    // scales/mins + 128 bytes of 4-bit quants (two per byte).
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 144;

    let num_blocks = k / BLOCK_SIZE;
    debug_assert!(k % BLOCK_SIZE == 0, "k {} not a multiple of {}", k, BLOCK_SIZE);
    debug_assert!(act.len() >= k, "act.len() {} < k {}", act.len(), k);
    debug_assert!(
        blocks.len() >= num_blocks * BLOCK_BYTES,
        "blocks.len() {} < required {}",
        blocks.len(),
        num_blocks * BLOCK_BYTES
    );

    // One vector accumulator for the whole row; reduced to a scalar at the end.
    let mut total_acc = vdupq_n_f32(0.0);

    for b in 0..num_blocks {
        let block = &blocks[b * BLOCK_BYTES..];
        // Block-wide scale and minimum, stored as little-endian f16.
        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let sc = &block[4..16]; // packed 6-bit sub-block scales/mins
        let qs = &block[16..144]; // 4-bit quants, two per byte
        let act_block = &act[b * BLOCK_SIZE..];
        // Eight (scale, min) pairs, one per 32-weight sub-block.
        let (scales, mins) = unpack_q4k_q5k_scales(sc);
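
        // Each sub-block j dequantizes as w = dl*q - ml, with dl = d*scales[j]
        // and ml = dmin*mins[j]; each lane therefore adds a*(dl*q - ml).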
        for j in 0..8 {
            let dl = d * scales[j] as f32;
            let ml = dmin * mins[j] as f32;
            // Sub-blocks come in pairs: bytes qs[chunk*32..][..32] hold
            // sub-block 2*chunk in their low nibbles and sub-block
            // 2*chunk + 1 in their high nibbles.
            let chunk = j / 2;
            let is_high = j % 2 == 1;
            let qs_base = chunk * 32;
            let act_sub = &act_block[j * 32..];

            let dl_vec = vdupq_n_f32(dl);
            let ml_vec = vdupq_n_f32(ml);
            let mask_0f = vdupq_n_u32(0x0F);
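
            // Process the 32-weight sub-block one f32 vector (four quants) at a time.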
            for g in 0..(32 / F32_LANES) {
                let l_base = g * F32_LANES;
                // Gather four quant bytes and widen each into its own u32 lane.
                let q0 = qs[qs_base + l_base] as u32;
                let q1 = qs[qs_base + l_base + 1] as u32;
                let q2 = qs[qs_base + l_base + 2] as u32;
                let q3 = qs[qs_base + l_base + 3] as u32;
                let lo = vcreate_u32(q0 as u64 | (q1 as u64) << 32);
                let hi = vcreate_u32(q2 as u64 | (q3 as u64) << 32);
                let raw = vcombine_u32(lo, hi); // u32x4 = [q0, q1, q2, q3]
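                // Keep only this sub-block's nibble from each byte.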
                let nibbles = if is_high {
                    vandq_u32(vshrq_n_u32::<4>(raw), mask_0f)
                } else {
                    vandq_u32(raw, mask_0f)
                };
                let q_f32 = vcvtq_f32_u32(nibbles);
                let a = vld1q_f32(act_sub.as_ptr().add(l_base));
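                // Accumulate a*(dl*q) via FMA, then subtract ml*a: a*(dl*q - ml).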
                let aq = vmulq_f32(a, q_f32);
                total_acc = vfmaq_f32(total_acc, dl_vec, aq);
                total_acc = vsubq_f32(total_acc, vmulq_f32(ml_vec, a));
            }
        }
    }

    hsum_f32_neon(total_acc)
}
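
// Minimal sanity sketch (illustrative, not exhaustive): an all-zero block has
// d = dmin = 0, so every weight dequantizes to 0 and the dot product must be
// exactly 0.0 regardless of the activations or quant nibbles.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;

    #[test]
    fn zero_q4k_block_dots_to_zero() {
        let k = 256;
        let act: Vec<f32> = (0..k).map(|i| i as f32 * 0.25).collect();
        let blocks = vec![0u8; 144]; // one Q4_K block, all zeros
        let out = unsafe { fused_dot_q4k_neon(&act, &blocks, k) };
        assert_eq!(out, 0.0);
    }
}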