use half::f16;
#[allow(clippy::needless_range_loop)]
pub fn fused_dot_q2k(act: &[f32], blocks: &[u8], k: usize) -> f32 {
const BLOCK_SIZE: usize = 256;
const BLOCK_BYTES: usize = 84;
let num_blocks = k / BLOCK_SIZE;
let mut sumf = 0.0f32;
for b in 0..num_blocks {
let block = &blocks[b * BLOCK_BYTES..];
let sc = &block[0..16];
let qs = &block[16..80];
let d = f16::from_le_bytes([block[80], block[81]]).to_f32();
let dmin = f16::from_le_bytes([block[82], block[83]]).to_f32();
let act_block = &act[b * BLOCK_SIZE..][..BLOCK_SIZE];
let mut scale_sum = 0.0f32;
let mut min_sum = 0.0f32;
let mut y = 0;
let mut is = 0;
let mut q_offset = 0;
for _n in 0..2 {
let q = &qs[q_offset..];
for shift in (0u8..8).step_by(2) {
let sub_scale = (sc[is] & 0x0F) as f32;
let sub_min = (sc[is] >> 4) as f32;
is += 1;
for l in 0..16 {
let qval = ((q[l] >> shift) & 3) as f32;
scale_sum += act_block[y] * sub_scale * qval;
min_sum += act_block[y] * sub_min;
y += 1;
}
let sub_scale = (sc[is] & 0x0F) as f32;
let sub_min = (sc[is] >> 4) as f32;
is += 1;
for l in 0..16 {
let qval = ((q[16 + l] >> shift) & 3) as f32;
scale_sum += act_block[y] * sub_scale * qval;
min_sum += act_block[y] * sub_min;
y += 1;
}
}
q_offset += 32;
}
sumf += d * scale_sum - dmin * min_sum;
}
sumf
}
#[cfg(test)]
mod tests {
use super::*;
use crate::quant::cpu::kernels::dequant_k_quants;
#[test]
fn test_fused_q2k_vs_dequant() {
let k = 256;
let act: Vec<f32> = (0..k).map(|i| (i as f32 - 128.0) * 0.01).collect();
let mut block = [0u8; 84];
block[0..16].fill(0x23);
block[16..80].fill(0xAA);
block[80..82].copy_from_slice(&f16::from_f32(2.0).to_le_bytes());
block[82..84].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
let fused = fused_dot_q2k(&act, &block, k);
let mut dequant_buf = vec![0.0f32; k];
dequant_k_quants::dequant_q2k(&block, &mut dequant_buf);
let reference: f32 = act.iter().zip(dequant_buf.iter()).map(|(a, b)| a * b).sum();
assert!(
(fused - reference).abs() < reference.abs() * 1e-4 + 1e-2,
"fused={fused}, reference={reference}",
);
}
}