/// Computes the dot product of Q4_K-quantized weights with Q8_0-quantized
/// activations without materializing either operand as `f32` buffers.
///
/// `q4k_data` is read as a sequence of 144-byte super-blocks, each laid out as:
///
/// | bytes     | content                                             |
/// |-----------|-----------------------------------------------------|
/// | `0..2`    | `d`, decoded with `read_f16`                        |
/// | `2..4`    | `dmin`, decoded with `read_f16`                     |
/// | `4..16`   | 12 bytes of packed sub-block scales and mins        |
/// | `16..144` | 128 bytes holding 4-bit quants, two per byte        |
///
/// Every super-block is processed as 8 sub-blocks of 16 bytes, and each
/// sub-block is paired with the next entry of `q8_blocks`. Within a sub-block,
/// byte `i` supplies the weight for activation `2 * i` from its low nibble and
/// the weight for activation `2 * i + 1` from its high nibble. A weight is
/// reconstructed as `d * scale * q - dmin * min`, an activation as
/// `q8_block.scale * quant`.
///
/// NOTE(review): the nibble-to-value ordering above has to agree with whatever
/// layout the Q4_K producer/dequantizer uses — confirm against that code.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` when
/// * `q4k_data.len()` is not a multiple of 144, or
/// * `q8_blocks.len()` differs from `(super-block count * QK_K) / 32`.
pub fn fused_q4k_q8_dot(q4k_data: &[u8], q8_blocks: &[Q8_0Block]) -> Result<f32> {
    const SUPER_BLOCK_BYTES: usize = 144;

    if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
        let reason = format!(
            "Q4_K data length {} is not a multiple of super-block size {}",
            q4k_data.len(),
            SUPER_BLOCK_BYTES
        );
        return Err(RealizarError::InvalidShape { reason });
    }

    let super_block_count = q4k_data.len() / SUPER_BLOCK_BYTES;
    let expected_values = super_block_count * QK_K;
    let expected_q8_blocks = expected_values / 32;
    if q8_blocks.len() != expected_q8_blocks {
        let reason = format!(
            "Q8_0 block count {} doesn't match expected {} (for {} Q4_K values)",
            q8_blocks.len(),
            expected_q8_blocks,
            expected_values
        );
        return Err(RealizarError::InvalidShape { reason });
    }

    let mut total = 0.0f32;
    // Position of the next unconsumed activation block; advances once per sub-block.
    let mut q8_cursor = 0;

    for super_block in q4k_data.chunks_exact(SUPER_BLOCK_BYTES) {
        let d = read_f16(&super_block[0..2]);
        let dmin = read_f16(&super_block[2..4]);

        let mut scales = [0u8; 12];
        scales.copy_from_slice(&super_block[4..16]);

        let packed_quants = &super_block[16..144];

        for (block_idx, packed) in packed_quants.chunks_exact(16).enumerate() {
            let (scale, min) = extract_scale_min(&scales, block_idx);

            let activations = &q8_blocks[q8_cursor];
            q8_cursor += 1;
            let act_scale = activations.scale;

            // Both factors are constant for the whole sub-block. Computing them
            // once keeps the exact same sequence of float operations per value.
            let weight_scale = d * scale;
            let weight_offset = dmin * min;

            for (pair_idx, &byte) in packed.iter().enumerate() {
                let even = pair_idx * 2;

                // Low nibble -> even-indexed activation.
                let w_even = weight_scale * f32::from(byte & 0x0F) - weight_offset;
                let a_even = act_scale * f32::from(activations.quants[even]);
                total += w_even * a_even;

                // High nibble -> the following odd-indexed activation.
                let w_odd = weight_scale * f32::from(byte >> 4) - weight_offset;
                let a_odd = act_scale * f32::from(activations.quants[even + 1]);
                total += w_odd * a_odd;
            }
        }
    }

    Ok(total)
}