use half::f16;
use super::quantize_act_q8k::Q8K_BLOCK_BYTES;
/// Portable scalar kernel: dot product of Q8_K-quantized activations with
/// Q2_K-quantized weights, computed directly on the packed bytes.
///
/// Processes `k / 256` whole blocks; any trailing `k % 256` elements are ignored.
///
/// Bytes read from each 84-byte weight block:
/// * `0..16`  – one scale byte per group of 16 elements; the low nibble scales
///   the integer dot product, the high nibble scales the activation group sum.
/// * `16..80` – 2-bit quantized values, four per byte.
/// * `80..82` – `f16` (little-endian) multiplier for the dot-product term.
/// * `82..84` – `f16` (little-endian) multiplier for the subtracted term.
///
/// Bytes read from each activation block (stride `Q8K_BLOCK_BYTES`):
/// * `0..4`     – `f32` (little-endian) scale.
/// * `4..260`   – 256 values reinterpreted as `i8`.
/// * `260..292` – 16 little-endian `i16` values (presumably the per-group sums
///   of the quantized activations — confirm against the Q8_K quantizer).
///
/// # Panics
///
/// Panics if `weight` or `act_q8k` is too short to hold the bytes read for
/// `k / 256` blocks.
fn fused_dot_q2k_q8k_scalar(act_q8k: &[u8], weight: &[u8], k: usize) -> f32 {
    const Q2K_BLOCK_BYTES: usize = 84;
    const Q2K_BLOCK_SIZE: usize = 256;

    let mut total = 0.0f32;
    for block_idx in 0..k / Q2K_BLOCK_SIZE {
        // Carve the weight block into its fields.
        let w_block = &weight[block_idx * Q2K_BLOCK_BYTES..];
        let scales = &w_block[..16];
        let packed = &w_block[16..80];
        let w_scale = f16::from_le_bytes([w_block[80], w_block[81]]).to_f32();
        let w_min = f16::from_le_bytes([w_block[82], w_block[83]]).to_f32();

        // Carve the activation block into its fields.
        let a_block = &act_q8k[block_idx * Q8K_BLOCK_BYTES..];
        let a_scale = f32::from_le_bytes(a_block[0..4].try_into().expect("exact-size slice"));
        let a_quants = &a_block[4..260];
        let a_group_sums = &a_block[260..292];

        // Offset term: each stored group sum weighted by the high nibble of its scale byte.
        let min_acc: i32 = a_group_sums
            .chunks_exact(2)
            .zip(scales)
            .map(|(pair, &s)| i32::from(i16::from_le_bytes([pair[0], pair[1]])) * i32::from(s >> 4))
            .sum();

        // Main term: walk the 16 activation groups in order. Group `g` uses scale
        // byte `g`; every 8 groups advance to the next 32 packed bytes, every 2
        // groups advance the bit position by 2, and odd groups read the upper 16
        // of those 32 bytes.
        let mut dot_acc = 0i32;
        for (group, acts) in a_quants.chunks_exact(16).enumerate() {
            let shift = ((group % 8) / 2) * 2;
            let start = (group / 8) * 32 + (group % 2) * 16;
            let bits = &packed[start..start + 16];
            let partial: i32 = acts
                .iter()
                .zip(bits)
                .map(|(&a, &p)| i32::from(a as i8) * i32::from((p >> shift) & 3))
                .sum();
            dot_acc += i32::from(scales[group] & 0x0F) * partial;
        }

        let dot_scale = w_scale * a_scale;
        let min_scale = w_min * a_scale;
        total += dot_scale * dot_acc as f32 - min_scale * min_acc as f32;
    }
    total
}
/// Dot product of Q8_K-quantized activations (`act_q8k`) with Q2_K-quantized
/// weights (`weight`) over `k` elements, computed on the packed bytes.
///
/// Currently forwards unconditionally to the scalar implementation, which
/// processes `k / 256` whole blocks and ignores any remainder.
///
/// # Panics
///
/// Panics if either slice is too short for `k / 256` blocks.
pub fn fused_dot_q2k_q8k(act_q8k: &[u8], weight: &[u8], k: usize) -> f32 {
fused_dot_q2k_q8k_scalar(act_q8k, weight, k)
}
#[cfg(test)]
mod tests {
    use super::super::quantize_act_q8k::{Q8K_BLOCK_BYTES, quantize_f32_to_q8k};
    use super::*;
    use crate::quant::cpu::kernels::dequant_k_quants;

    /// Size in bytes of one weight block in the fixtures below.
    const WEIGHT_BLOCK_BYTES: usize = 84;
    /// Number of elements covered by one block.
    const ELEMS_PER_BLOCK: usize = 256;

    /// Quantizes `act`, runs the fused kernel against `weight`, and checks the
    /// result against a plain f32 dot product with the dequantized weights
    /// (5% relative tolerance plus 1.0 absolute slack).
    fn assert_fused_matches_dequant_reference(act: &[f32], weight: &[u8]) {
        let k = act.len();

        let mut act_q8k = vec![0u8; Q8K_BLOCK_BYTES * (k / ELEMS_PER_BLOCK)];
        quantize_f32_to_q8k(act, &mut act_q8k);
        let result = fused_dot_q2k_q8k(&act_q8k, weight, k);

        let mut dequant_buf = vec![0.0f32; k];
        dequant_k_quants::dequant_q2k(weight, &mut dequant_buf);
        let reference: f32 = act.iter().zip(dequant_buf.iter()).map(|(a, b)| a * b).sum();

        let diff = (result - reference).abs();
        assert!(
            diff < reference.abs() * 0.05 + 1.0,
            "q8k={result}, f32_ref={reference}, diff={}",
            diff
        );
    }

    #[test]
    fn test_fused_q2k_q8k_vs_f32_dot() {
        // Single block with uniform scale bytes and uniform packed values.
        let act: Vec<f32> = (0..ELEMS_PER_BLOCK)
            .map(|i| (i as f32 - 128.0) * 0.01)
            .collect();

        let mut weight = [0u8; WEIGHT_BLOCK_BYTES];
        weight[..16].fill(0x23);
        weight[16..80].fill(0xAA);
        weight[80..82].copy_from_slice(&f16::from_f32(2.0).to_le_bytes());
        weight[82..84].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());

        assert_fused_matches_dequant_reference(&act, &weight);
    }

    #[test]
    fn test_fused_q2k_q8k_large() {
        // Sixteen blocks with varying scales, packed bytes and f16 multipliers.
        const BLOCKS: usize = 16;
        let k = BLOCKS * ELEMS_PER_BLOCK;
        let act: Vec<f32> = (0..k).map(|i| ((i as f32) * 0.003).sin()).collect();

        let mut weight = vec![0u8; WEIGHT_BLOCK_BYTES * BLOCKS];
        for (blk, block) in weight.chunks_exact_mut(WEIGHT_BLOCK_BYTES).enumerate() {
            block[..16].fill(0x12 + (blk % 4) as u8);
            for (i, byte) in block.iter_mut().enumerate().take(80).skip(16) {
                *byte = ((blk * 17 + i * 31) % 256) as u8;
            }
            block[80..82].copy_from_slice(&f16::from_f32(0.5 + blk as f32 * 0.03).to_le_bytes());
            block[82..84].copy_from_slice(&f16::from_f32(0.1 + blk as f32 * 0.01).to_le_bytes());
        }

        assert_fused_matches_dequant_reference(&act, &weight);
    }
}