#![allow(missing_docs)]
use super::{f16_to_f32, SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
#[inline(always)]
fn extract_q6k_scalar(ql: &[u8], qh: &[u8], idx: usize) -> i8 {
    // Low 4 bits: two values per `ql` byte (even idx -> low nibble, odd -> high).
    let nibble = (ql[idx >> 1] >> ((idx & 1) * 4)) & 0x0F;
    // High 2 bits: four values per `qh` byte, 2 bits each.
    let top = (qh[idx >> 2] >> ((idx & 3) * 2)) & 0x03;
    // Reassemble the 6-bit quant and recenter from [0, 63] to [-32, 31].
    ((nibble | (top << 4)) as i8) - 32
}
#[inline(always)]
fn process_q6k_superblock_scalar(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
) -> f32 {
    // Q6_K super-block layout: 128 B ql | 64 B qh | 16 B i8 scales | 2 B f16 d.
    let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
    let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
    let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
    let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
    let mut acc = 0.0f32;
    // 16 groups of 16 values; each group carries its own signed 8-bit scale.
    for (group, &raw_scale) in scales.iter().enumerate() {
        let ds = d * (raw_scale as i8) as f32;
        let base = group * 16;
        for j in 0..16 {
            let idx = base + j;
            let input_idx = input_offset + idx;
            // The final super-block of a row may be padding past `in_dim`.
            if input_idx < in_dim {
                acc += ds * extract_q6k_scalar(ql, qh, idx) as f32 * input[input_idx];
            }
        }
    }
    acc
}
pub fn matmul_q6k_f32_scalar(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    // Scalar reference kernel: each output element is the dot product of one
    // quantized row (stored as consecutive Q6_K super-blocks) with `input`.
    assert_eq!(input.len(), in_dim, "Input length mismatch");
    let blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_bytes = blocks_per_row * SUPER_BLOCK_BYTES;
    (0..out_dim)
        .map(|row| {
            let row_start = row * row_bytes;
            (0..blocks_per_row)
                .map_while(|sb| {
                    let begin = row_start + sb * SUPER_BLOCK_BYTES;
                    // Stop at the first super-block that would run past the buffer.
                    let block = q6k_data.get(begin..begin + SUPER_BLOCK_BYTES)?;
                    Some(process_q6k_superblock_scalar(
                        block,
                        input,
                        sb * SUPER_BLOCK_SIZE,
                        in_dim,
                    ))
                })
                .sum::<f32>()
        })
        .collect()
}
#[inline(always)]
fn extract_q6k_values(ql: &[u8], qh: &[u8], idx_base: usize) -> [i32; 8] {
    // Decode eight consecutive 6-bit quants starting at `idx_base`, recentered
    // from [0, 63] to [-32, 31]. Same bit layout as `extract_q6k_scalar`.
    std::array::from_fn(|lane| {
        let idx = idx_base + lane;
        let low = (ql[idx / 2] >> ((idx % 2) * 4)) & 0x0F;
        let high = (qh[idx / 4] >> ((idx % 4) * 2)) & 0x03;
        ((low | (high << 4)) as i32) - 32
    })
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn hsum_q6k_avx2(acc: std::arch::x86_64::__m256) -> f32 {
    // Horizontal sum of all eight f32 lanes: fold 256 -> 128 -> 64 -> 32 bits.
    use std::arch::x86_64::*;
    let upper = _mm256_extractf128_ps(acc, 1);
    let lower = _mm256_castps256_ps128(acc);
    let quad = _mm_add_ps(lower, upper);
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 1));
    _mm_cvtss_f32(single)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn process_q6k_superblock_avx2(
    sb_data: &[u8],
    input: &[f32],
    input_offset: usize,
    in_dim: usize,
    acc: &mut std::arch::x86_64::__m256,
) {
    // Accumulate one Q6_K super-block's contribution to a row dot product into
    // the 8-lane accumulator `acc`. Layout: 128 B ql | 64 B qh | 16 B scales | 2 B f16 d.
    unsafe {
        use std::arch::x86_64::*;
        let ql = sb_data.get(0..128).expect("Q6_K: need ≥128 bytes for ql");
        let qh = sb_data.get(128..192).expect("Q6_K: need ≥192 bytes for qh");
        let scales = sb_data.get(192..208).expect("Q6_K: need ≥208 bytes for scales");
        let d = f16_to_f32(u16::from_le_bytes([sb_data[208], sb_data[209]]));
        let d_vec = _mm256_set1_ps(d);
        for group in 0..16 {
            let scale = (scales[group] as i8) as f32;
            let ds_vec = _mm256_mul_ps(d_vec, _mm256_set1_ps(scale));
            let group_offset = group * 16;
            let input_group = input_offset + group_offset;
            for half in 0..2 {
                let half_offset = half * 8;
                let input_base = input_group + half_offset;
                if input_base >= in_dim {
                    // Entirely past the row's logical length (padding) — nothing to add.
                    continue;
                }
                if input_base + 8 <= in_dim {
                    // Full 8-lane group: vectorized dequantize + FMA.
                    let q6_vals = extract_q6k_values(ql, qh, group_offset + half_offset);
                    let q6_i32 = _mm256_loadu_si256(q6_vals.as_ptr() as *const __m256i);
                    let q6_f32 = _mm256_cvtepi32_ps(q6_i32);
                    let x = _mm256_loadu_ps(input.as_ptr().add(input_base));
                    let dequant = _mm256_mul_ps(ds_vec, q6_f32);
                    *acc = _mm256_fmadd_ps(dequant, x, *acc);
                } else {
                    // BUGFIX: partial tail when `in_dim` is not a multiple of 8.
                    // The previous implementation skipped these lanes entirely,
                    // diverging from the scalar path, which processes every
                    // element with `input_idx < in_dim`. Accumulate the valid
                    // lanes scalar-wise and fold the sum into lane 0 of `acc`.
                    let q6_vals = extract_q6k_values(ql, qh, group_offset + half_offset);
                    let mut tail = 0.0f32;
                    for lane in 0..(in_dim - input_base) {
                        tail += d * scale * q6_vals[lane] as f32 * input[input_base + lane];
                    }
                    *acc = _mm256_add_ps(
                        *acc,
                        _mm256_setr_ps(tail, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
                    );
                }
            }
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn matmul_q6k_f32_avx2(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    // AVX2+FMA single-threaded kernel: one vector accumulator per output row,
    // horizontally reduced after all super-blocks of the row are consumed.
    unsafe {
        use std::arch::x86_64::*;
        let blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
        let row_bytes = blocks_per_row * SUPER_BLOCK_BYTES;
        let mut output = vec![0.0f32; out_dim];
        for (row, slot) in output.iter_mut().enumerate() {
            let row_start = row * row_bytes;
            let mut acc = _mm256_setzero_ps();
            for sb in 0..blocks_per_row {
                let begin = row_start + sb * SUPER_BLOCK_BYTES;
                // Stop at the first super-block that would run past the buffer.
                let Some(block) = q6k_data.get(begin..begin + SUPER_BLOCK_BYTES) else {
                    break;
                };
                process_q6k_superblock_avx2(block, input, sb * SUPER_BLOCK_SIZE, in_dim, &mut acc);
            }
            *slot = hsum_q6k_avx2(acc);
        }
        output
    }
}
#[inline]
pub fn matmul_q6k_f32_dispatch(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    // Sanity checks are debug-only: release builds trust the caller's sizing.
    debug_assert_eq!(input.len(), in_dim, "Q6K dispatch: input length mismatch");
    debug_assert!(
        q6k_data.len() >= crate::contracts::Q6_K.expected_bytes(out_dim, in_dim),
        "Q6K dispatch: buffer too small: {} bytes for [{}, {}] (need {})",
        q6k_data.len(),
        out_dim,
        in_dim,
        crate::contracts::Q6_K.expected_bytes(out_dim, in_dim),
    );
    // Above this many multiply-adds the thread-spawn overhead is amortized;
    // below it a single core (SIMD when available) wins.
    const PARALLEL_WORK_THRESHOLD: usize = 8_000_000;
    if out_dim * in_dim >= PARALLEL_WORK_THRESHOLD {
        return matmul_q6k_f32_parallel(q6k_data, input, out_dim, in_dim);
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: avx2+fma availability was just verified at runtime.
            return unsafe { matmul_q6k_f32_avx2(q6k_data, input, out_dim, in_dim) };
        }
    }
    matmul_q6k_f32_scalar(q6k_data, input, out_dim, in_dim)
}
#[cfg(target_arch = "x86_64")]
fn matmul_q6k_f32_parallel(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    use std::thread;
    // Cap worker count at 12; fall back to 4 if parallelism can't be queried.
    let num_threads = thread::available_parallelism().map(|p| p.get()).unwrap_or(4).min(12);
    let chunk_size = (out_dim + num_threads - 1) / num_threads;
    let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
    let mut output = vec![0.0f32; out_dim];
    // Detect SIMD support once, outside the workers.
    let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    // Scoped threads let each worker borrow its disjoint output chunk directly;
    // `&[u8]`/`&[f32]` are `Copy`, so the inputs move into each closure freely.
    thread::scope(|s| {
        for (chunk_idx, chunk) in output.chunks_mut(chunk_size).enumerate() {
            let start_row = chunk_idx * chunk_size;
            s.spawn(move || {
                if has_avx2 {
                    // SAFETY: avx2+fma availability was verified above.
                    unsafe {
                        compute_chunk_avx2(
                            q6k_data,
                            input,
                            chunk,
                            start_row,
                            out_dim,
                            in_dim,
                            num_blocks_per_row,
                            row_bytes,
                        );
                    }
                } else {
                    compute_chunk_scalar(
                        q6k_data,
                        input,
                        chunk,
                        start_row,
                        out_dim,
                        in_dim,
                        num_blocks_per_row,
                        row_bytes,
                    );
                }
            });
        }
    });
    output
}
/// Non-x86_64 stand-in for the parallel path: no threaded kernel is provided
/// here, so large problems simply run through the single-threaded scalar
/// implementation. Keeps `matmul_q6k_f32_dispatch` portable across targets.
#[cfg(not(target_arch = "x86_64"))]
fn matmul_q6k_f32_parallel(
    q6k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    matmul_q6k_f32_scalar(q6k_data, input, out_dim, in_dim)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn compute_chunk_avx2(
    q6k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    // AVX2 worker: fills `chunk` with the dot products for rows
    // start_row .. start_row + chunk.len(), clamped to `out_dim`.
    unsafe {
        use std::arch::x86_64::*;
        for (offset, slot) in chunk.iter_mut().enumerate() {
            let row = start_row + offset;
            if row >= out_dim {
                break;
            }
            let row_start = row * row_bytes;
            let mut acc = _mm256_setzero_ps();
            for sb in 0..num_blocks_per_row {
                let begin = row_start + sb * SUPER_BLOCK_BYTES;
                // Stop at the first super-block that would run past the buffer.
                let Some(block) = q6k_data.get(begin..begin + SUPER_BLOCK_BYTES) else {
                    break;
                };
                process_q6k_superblock_avx2(block, input, sb * SUPER_BLOCK_SIZE, in_dim, &mut acc);
            }
            *slot = hsum_q6k_avx2(acc);
        }
    }
}
pub(crate) fn compute_chunk_scalar(
    q6k_data: &[u8],
    input: &[f32],
    chunk: &mut [f32],
    start_row: usize,
    out_dim: usize,
    in_dim: usize,
    num_blocks_per_row: usize,
    row_bytes: usize,
) {
    // Scalar worker: fills `chunk` with the dot products for rows
    // start_row .. start_row + chunk.len(), clamped to `out_dim`.
    for (offset, slot) in chunk.iter_mut().enumerate() {
        let row = start_row + offset;
        if row >= out_dim {
            break;
        }
        let row_start = row * row_bytes;
        *slot = (0..num_blocks_per_row)
            .map_while(|sb| {
                let begin = row_start + sb * SUPER_BLOCK_BYTES;
                // Stop at the first super-block that would run past the buffer.
                let block = q6k_data.get(begin..begin + SUPER_BLOCK_BYTES)?;
                Some(process_q6k_superblock_scalar(
                    block,
                    input,
                    sb * SUPER_BLOCK_SIZE,
                    in_dim,
                ))
            })
            .sum::<f32>();
    }
}
/// Q6_K × f32 matrix–vector multiply: returns `out_dim` dot products of the
/// quantized rows in `q6k_data` against `input` (length `in_dim`).
///
/// Thin public entry point; delegates to [`matmul_q6k_f32_dispatch`], which
/// selects the parallel, AVX2, or scalar path.
pub fn matmul_q6k_f32(q6k_data: &[u8], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
    matmul_q6k_f32_dispatch(q6k_data, input, out_dim, in_dim)
}