#![allow(missing_docs)]
//! Tiled matrix-vector multiplication over Q4_K-quantized weights.
use super::config::TilingConfig;

/// Weights per Q4_K superblock.
pub const Q4K_SUPERBLOCK_SIZE: usize = 256;
/// Bytes per packed Q4_K superblock: 2 (d) + 2 (dmin) + 12 (scales) + 128 (quants).
pub const Q4K_SUPERBLOCK_BYTES: usize = 144;
/// Tiled Q4_K matvec: computes `output[m] = W[m x k] * input[k]`, where `W`
/// is stored as packed Q4_K superblocks.
#[derive(Debug, Clone)]
pub struct TiledQ4KMatvec {
    /// Tiling parameters for the kernel.
    pub config: TilingConfig,
    /// Number of output rows.
    pub m: usize,
    /// Input dimension; must be a multiple of [`Q4K_SUPERBLOCK_SIZE`].
    pub k: usize,
}
impl TiledQ4KMatvec {
    /// Creates a tiler for an `m x k` weight matrix.
    ///
    /// # Panics
    /// Panics if `k` is not a multiple of [`Q4K_SUPERBLOCK_SIZE`].
    #[must_use]
    pub fn new(m: usize, k: usize) -> Self {
assert!(
k % Q4K_SUPERBLOCK_SIZE == 0,
"K dimension ({}) must be aligned to Q4_K superblock size ({})",
k,
Q4K_SUPERBLOCK_SIZE
);
Self { config: TilingConfig::cpu_avx2_q4k_matvec(), m, k }
}
    /// Number of Q4_K superblocks along one row (`k / 256`).
    #[must_use]
pub fn superblocks_per_row(&self) -> usize {
self.k / Q4K_SUPERBLOCK_SIZE
}
    /// Total superblock count across all `m` rows.
    #[must_use]
pub fn total_superblocks(&self) -> usize {
self.m * self.superblocks_per_row()
}
    /// Byte offset of `row`'s first superblock in the packed weight buffer.
    #[must_use]
    #[inline]
pub fn weight_row_offset(&self, row: usize) -> usize {
row * self.superblocks_per_row() * Q4K_SUPERBLOCK_BYTES
}
    /// Estimates how many rows' worth of weights fit in an L2 cache of
    /// `l2_bytes` alongside the shared input vector.
    #[must_use]
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn optimal_parallel_rows(&self, l2_bytes: usize) -> usize {
        // Q4_K packs 256 weights into 144 bytes: 144 / 256 = 0.5625 bytes per weight.
        let row_bytes = (self.k as f32 * 0.5625) as usize;
        // The f32 input vector is read by every row, so reserve it first.
        let input_bytes = self.k * 4;
        let available = l2_bytes.saturating_sub(input_bytes);
        // Floor of 4 rows keeps some parallelism even when the cache is tiny.
        (available / row_bytes).max(4)
    }
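
    // Worked example (illustrative numbers, not from any measured CPU): with
    // k = 4096 and a 1 MiB L2, row_bytes = 4096 * 0.5625 = 2304 and
    // input_bytes = 16_384, so (1_048_576 - 16_384) / 2304 = 448 rows of
    // weights fit in cache next to the shared input vector.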
    /// Computes `output = W * input` row by row using the scalar dequant path.
    ///
    /// # Panics
    /// Panics if any buffer length disagrees with `m`, `k`, or the packed weight size.
    pub fn execute_scalar(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
assert_eq!(weights.len(), self.total_superblocks() * Q4K_SUPERBLOCK_BYTES);
assert_eq!(input.len(), self.k);
assert_eq!(output.len(), self.m);
let superblocks_per_row = self.superblocks_per_row();
for row in 0..self.m {
let mut sum = 0.0f32;
let row_offset = row * superblocks_per_row * Q4K_SUPERBLOCK_BYTES;
for sb in 0..superblocks_per_row {
let sb_offset = row_offset + sb * Q4K_SUPERBLOCK_BYTES;
let sb_data = &weights[sb_offset..sb_offset + Q4K_SUPERBLOCK_BYTES];
let input_offset = sb * Q4K_SUPERBLOCK_SIZE;
sum += self.scalar_superblock_dot(
sb_data,
&input[input_offset..input_offset + Q4K_SUPERBLOCK_SIZE],
);
}
output[row] = sum;
}
}
#[cfg(feature = "parallel")]
pub fn execute_parallel(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
use rayon::prelude::*;
assert_eq!(weights.len(), self.total_superblocks() * Q4K_SUPERBLOCK_BYTES);
assert_eq!(input.len(), self.k);
assert_eq!(output.len(), self.m);
let superblocks_per_row = self.superblocks_per_row();
let row_stride = superblocks_per_row * Q4K_SUPERBLOCK_BYTES;
output.par_iter_mut().enumerate().for_each(|(row, out)| {
let mut sum = 0.0f32;
let row_offset = row * row_stride;
for sb in 0..superblocks_per_row {
let sb_offset = row_offset + sb * Q4K_SUPERBLOCK_BYTES;
let sb_data = &weights[sb_offset..sb_offset + Q4K_SUPERBLOCK_BYTES];
let input_offset = sb * Q4K_SUPERBLOCK_SIZE;
sum += self.scalar_superblock_dot(
sb_data,
&input[input_offset..input_offset + Q4K_SUPERBLOCK_SIZE],
);
}
*out = sum;
});
}
#[cfg(not(feature = "parallel"))]
pub fn execute_parallel(&self, weights: &[u8], input: &[f32], output: &mut [f32]) {
self.execute_scalar(weights, input, output);
}
    /// Dot product of one 256-element Q4_K superblock with the matching input
    /// slice. Each weight dequantizes as `d * scale * q - dmin * min`, with
    /// eight 6-bit (scale, min) pairs covering 32 elements apiece.
    #[inline]
    fn scalar_superblock_dot(&self, sb_data: &[u8], input: &[f32]) -> f32 {
let d = f16_to_f32(sb_data.get(0..2).expect("Q4_K: need ≥2 bytes for d"));
let dmin = f16_to_f32(sb_data.get(2..4).expect("Q4_K: need ≥4 bytes for dmin"));
let scales = sb_data.get(4..16).expect("Q4_K: need ≥16 bytes for scales");
let qs = sb_data.get(16..144).expect("Q4_K: need ≥144 bytes for qs");
let scale_mins = precompute_scales_mins(scales);
let mut sum = 0.0f32;
for pair in 0..4 {
let sb_lo = pair * 2;
let sb_hi = pair * 2 + 1;
let (sc_lo, mn_lo) = scale_mins[sb_lo];
let (sc_hi, mn_hi) = scale_mins[sb_hi];
let d_scale_lo = d * sc_lo;
let dm_lo = dmin * mn_lo;
let d_scale_hi = d * sc_hi;
let dm_hi = dmin * mn_hi;
            // Each quant byte holds two nibbles: the low nibble belongs to
            // sub-block `2 * pair`, the high nibble to sub-block `2 * pair + 1`.
            let q_offset = pair * 32;
            let input_lo = pair * 64;
            let input_hi = pair * 64 + 32;
let mut pair_sum = 0.0f32;
for i in 0..32 {
let byte = qs[q_offset + i];
let q_lo = (byte & 0x0F) as f32;
let q_hi = (byte >> 4) as f32;
let val_lo = d_scale_lo * q_lo - dm_lo;
let val_hi = d_scale_hi * q_hi - dm_hi;
pair_sum += val_lo * input[input_lo + i];
pair_sum += val_hi * input[input_hi + i];
}
sum += pair_sum;
}
sum
}
    /// Summarizes memory traffic and arithmetic intensity for this matvec shape.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn stats(&self) -> TilingStats {
let bytes_per_row = self.superblocks_per_row() * Q4K_SUPERBLOCK_BYTES;
let total_weight_bytes = self.m * bytes_per_row;
let input_bytes = self.k * 4;
let output_bytes = self.m * 4;
TilingStats {
total_weight_bytes,
input_bytes,
output_bytes,
superblocks: self.total_superblocks(),
            arithmetic_ops: self.m * self.k * 2,
            arithmetic_intensity: (self.m * self.k * 2) as f32
                / (total_weight_bytes + input_bytes) as f32,
}
}
}
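
// Usage sketch (hedged): drives the scalar path end-to-end with zeroed
// buffers sized from the tiler itself; the 8 x 512 shape is illustrative,
// not taken from any particular model.
#[cfg(test)]
mod matvec_usage_sketch {
    use super::*;

    #[test]
    fn scalar_matvec_accepts_consistent_shapes() {
        let mv = TiledQ4KMatvec::new(8, 512);
        let weights = vec![0u8; mv.total_superblocks() * Q4K_SUPERBLOCK_BYTES];
        let input = vec![0.5f32; mv.k];
        let mut output = vec![0.0f32; mv.m];
        mv.execute_scalar(&weights, &input, &mut output);
        // All-zero weight bytes decode to d = dmin = 0, so every output is 0.
        assert!(output.iter().all(|&v| v == 0.0));
    }
}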
/// Memory and compute statistics reported by [`TiledQ4KMatvec::stats`].
#[derive(Debug, Clone)]
pub struct TilingStats {
    pub total_weight_bytes: usize,
    pub input_bytes: usize,
    pub output_bytes: usize,
    pub superblocks: usize,
    /// Counted as `2 * m * k` (one multiply and one add per weight).
    pub arithmetic_ops: usize,
    /// Operations per byte of weights plus input read.
    pub arithmetic_intensity: f32,
}
/// Decodes a little-endian IEEE 754 half-precision value from `bytes[0..2]`.
///
/// # Panics
/// Panics if `bytes` holds fewer than 2 bytes.
#[inline]
pub fn f16_to_f32(bytes: &[u8]) -> f32 {
let bits = u16::from_le_bytes([bytes[0], bytes[1]]);
f16_bits_to_f32(bits)
}
#[inline(always)]
fn f16_bits_to_f32(bits: u16) -> f32 {
let sign = (bits >> 15) & 0x1;
let exponent = (bits >> 10) & 0x1F;
let mantissa = bits & 0x3FF;
    if exponent != 0 && exponent != 31 {
        // Normal number: rebias the exponent (f32 bias 127 - f16 bias 15 = 112)
        // and widen the mantissa from 10 to 23 bits.
        let f32_exp = u32::from(exponent) + 112;
        let f32_mant = u32::from(mantissa) << 13;
        let f32_bits = (u32::from(sign) << 31) | (f32_exp << 23) | f32_mant;
        return f32::from_bits(f32_bits);
}
f16_special_to_f32(sign, exponent, mantissa)
}
/// Slow path for half-precision zeros, subnormals, infinities, and NaNs.
#[cold]
#[inline(never)]
fn f16_special_to_f32(sign: u16, exponent: u16, mantissa: u16) -> f32 {
if exponent == 0 {
if mantissa == 0 {
return if sign == 1 { -0.0 } else { 0.0 };
}
        // Subnormal: value = (mantissa / 1024) * 2^-14.
        const TWO_POW_NEG_14: f32 = 6.103_515_625e-5;
        let m = mantissa as f32 * (1.0 / 1024.0);
        let result = m * TWO_POW_NEG_14;
return if sign == 1 { -result } else { result };
}
if mantissa == 0 {
if sign == 1 {
f32::NEG_INFINITY
} else {
f32::INFINITY
}
} else {
f32::NAN
}
}
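
// Spot checks for the half-precision decoder (a hedged sketch; the bit
// patterns are standard IEEE 754 binary16 encodings chosen by hand).
#[cfg(test)]
mod f16_spot_checks {
    use super::f16_bits_to_f32;

    #[test]
    fn known_bit_patterns_decode() {
        assert_eq!(f16_bits_to_f32(0x3C00), 1.0); // exp 15 (bias 15), mantissa 0
        assert_eq!(f16_bits_to_f32(0xC000), -2.0); // sign set, exp 16
        assert_eq!(f16_bits_to_f32(0x7C00), f32::INFINITY); // exp 31, mantissa 0
        assert!(f16_bits_to_f32(0x7C01).is_nan()); // exp 31, nonzero mantissa
        assert_eq!(f16_bits_to_f32(0x0001), 2.0f32.powi(-24)); // smallest subnormal
    }
}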
/// Unpacks the 6-bit scale and min for sub-block `idx` (0..8) from the
/// 12-byte packed `scales` field of a Q4_K superblock.
#[inline(always)]
#[allow(clippy::cast_precision_loss)]
pub fn extract_scale_min_6bit(scales: &[u8], idx: usize) -> (f32, f32) {
debug_assert!(scales.len() >= 12, "scales array must be at least 12 bytes");
debug_assert!(idx < 8, "idx must be < 8");
    if idx < 4 {
        // Lower half: the low 6 bits of bytes 0..4 hold scales, bytes 4..8 hold mins.
        let scale = (scales[idx] & 0x3F) as u32;
        let min = (scales[4 + idx] & 0x3F) as u32;
        (scale as f32, min as f32)
    } else {
        // Upper half: the low 4 bits come from bytes 8..12; the high 2 bits
        // are borrowed from the top of the corresponding lower-half bytes.
        let i = idx - 4;
        let combo = scales[8 + i];
        let sc_low4 = (combo & 0x0F) as u32;
        let sc_high2 = ((scales[i] >> 6) & 0x03) as u32;
let scale = sc_low4 | (sc_high2 << 4);
let mn_low4 = ((combo >> 4) & 0x0F) as u32;
let mn_high2 = ((scales[4 + i] >> 6) & 0x03) as u32;
let min = mn_low4 | (mn_high2 << 4);
(scale as f32, min as f32)
}
}
/// Unpacks all eight (scale, min) pairs once per superblock so the inner
/// dot-product loop avoids repeated bit manipulation.
#[inline]
fn precompute_scales_mins(scales: &[u8]) -> [(f32, f32); 8] {
debug_assert!(scales.len() >= 12);
[
extract_scale_min_6bit(scales, 0),
extract_scale_min_6bit(scales, 1),
extract_scale_min_6bit(scales, 2),
extract_scale_min_6bit(scales, 3),
extract_scale_min_6bit(scales, 4),
extract_scale_min_6bit(scales, 5),
extract_scale_min_6bit(scales, 6),
extract_scale_min_6bit(scales, 7),
]
}
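
// Packing check for the split 6-bit scale/min encoding (hand-packed bytes;
// a hedged sketch of the layout the unpacker above expects, not a fixture
// taken from a real model file).
#[cfg(test)]
mod scale_packing_sketch {
    use super::extract_scale_min_6bit;

    #[test]
    fn upper_index_reassembles_split_bits() {
        let mut scales = [0u8; 12];
        // Target: scale 5 = 0b10_0011 (35), min 5 = 0b01_0110 (22).
        scales[9] = 0x63; // low nibble -> scale bits 0..4, high nibble -> min bits 0..4
        scales[1] = 0b10 << 6; // top 2 bits of byte 1 carry scale 5's bits 4..6
        scales[5] = 0b01 << 6; // top 2 bits of byte 5 carry min 5's bits 4..6
        assert_eq!(extract_scale_min_6bit(&scales, 5), (35.0, 22.0));
        // Lower indices read only the low 6 bits, so the borrowed high bits
        // above must not leak into them.
        assert_eq!(extract_scale_min_6bit(&scales, 1), (0.0, 0.0));
    }
}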