#![cfg(all(feature = "simd-avx512", target_arch = "x86_64"))]
use core::arch::x86_64::*;
use crate::error::{QuantError, QuantResult};
use crate::simd::avx512::util::{f16_to_f32, hsum_f32_avx512};
use crate::traits::QuantKernel;
use crate::types::QuantTensor;
/// Number of weights held by one Q6_K block, i.e. the number of `f32` values a
/// single block dequantizes to.
pub const BLOCK_SIZE: usize = 256;
/// Size in bytes of one packed Q6_K block as sliced by the kernels in this file:
/// 128 bytes of 4-bit low parts, 64 bytes of 2-bit high parts, 16 one-byte
/// signed scales and a trailing 2-byte value decoded by `f16_to_f32`.
pub const BLOCK_BYTES: usize = 210;
/// Q6_K quantization kernel implemented with AVX-512F intrinsics.
///
/// The type carries no state; all behavior lives in its `QuantKernel` impl.
// The underscore mirrors the "Q6_K" format name, hence the lint exemption.
#[allow(non_camel_case_types)]
pub struct Q6_KAvx512;
/// Unpacks 4 x 16 unsigned 6-bit quants from packed Q6_K data.
///
/// Reads 16 bytes at `ql_ptr`, 16 bytes at `ql_ptr + 32` and 16 bytes at
/// `qh_ptr`. For byte lane `l` the returned vectors hold, in order:
///
/// 1. low nibble of `ql[l]`       | bits 0..=1 of `qh[l]` shifted to bits 4..=5
/// 2. low nibble of `ql[l + 32]`  | bits 2..=3 of `qh[l]` shifted to bits 4..=5
/// 3. high nibble of `ql[l]`      | bits 4..=5 of `qh[l]`
/// 4. high nibble of `ql[l + 32]` | bits 6..=7 of `qh[l]` shifted to bits 4..=5
///
/// Every output byte is therefore in `0..=63`; no offset is subtracted here.
///
/// # Safety
///
/// `ql_ptr` must be valid for reads of 48 bytes and `qh_ptr` for reads of
/// 16 bytes (unaligned is fine), and the CPU must support AVX-512F.
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn decode_q6_group(
    ql_ptr: *const u8,
    qh_ptr: *const u8,
) -> (__m128i, __m128i, __m128i, __m128i) {
    let nibble_mask = _mm_set1_epi8(0x0F);
    // Selects bits 4..=5 of every byte, where the 2-bit high part must end up.
    let high_part_mask = _mm_set1_epi8(0x30);

    let packed_lo_a = _mm_loadu_si128(ql_ptr.cast::<__m128i>());
    let packed_lo_b = _mm_loadu_si128(ql_ptr.add(32).cast::<__m128i>());
    let packed_hi = _mm_loadu_si128(qh_ptr.cast::<__m128i>());

    // The shifts operate on 16-bit lanes, so bits leak across byte boundaries;
    // the masks applied right after each shift discard the leaked bits.
    let low_a = _mm_and_si128(packed_lo_a, nibble_mask);
    let low_b = _mm_and_si128(packed_lo_b, nibble_mask);
    let high_a = _mm_and_si128(_mm_srli_epi16(packed_lo_a, 4), nibble_mask);
    let high_b = _mm_and_si128(_mm_srli_epi16(packed_lo_b, 4), nibble_mask);

    // Move each 2-bit field of the high-part byte straight to bits 4..=5.
    let top_from_bits01 = _mm_and_si128(_mm_slli_epi16(packed_hi, 4), high_part_mask);
    let top_from_bits23 = _mm_and_si128(_mm_slli_epi16(packed_hi, 2), high_part_mask);
    let top_from_bits45 = _mm_and_si128(packed_hi, high_part_mask);
    let top_from_bits67 = _mm_and_si128(_mm_srli_epi16(packed_hi, 2), high_part_mask);

    (
        _mm_or_si128(low_a, top_from_bits01),
        _mm_or_si128(low_b, top_from_bits23),
        _mm_or_si128(high_a, top_from_bits45),
        _mm_or_si128(high_b, top_from_bits67),
    )
}
// NOTE(review): nothing in this file verifies at runtime that the CPU supports
// AVX-512F before the `#[target_feature]` kernels are invoked from these safe
// methods. Presumably the kernel dispatcher only selects this type after
// feature detection — confirm against the caller.
impl QuantKernel for Q6_KAvx512 {
    /// Dequantizes one Q6_K block into `BLOCK_SIZE` `f32` values.
    ///
    /// Only the first `BLOCK_BYTES` bytes of `block` and the first
    /// `BLOCK_SIZE` elements of `output` are handed to the kernel.
    ///
    /// # Errors
    ///
    /// Returns [`QuantError::BufferTooSmall`] when `block` holds fewer than
    /// `BLOCK_BYTES` bytes or `output` fewer than `BLOCK_SIZE` elements.
    fn dequant_block(&self, block: &[u8], output: &mut [f32]) -> QuantResult<()> {
        if block.len() < BLOCK_BYTES {
            return Err(QuantError::BufferTooSmall {
                needed: BLOCK_BYTES,
                available: block.len(),
            });
        }
        if output.len() < BLOCK_SIZE {
            return Err(QuantError::BufferTooSmall {
                needed: BLOCK_SIZE,
                available: output.len(),
            });
        }
        // Trim both buffers to exactly one block so the unsafe kernel sees the
        // same exact-length block the gemv path gives it.
        let block = &block[..BLOCK_BYTES];
        let output = &mut output[..BLOCK_SIZE];
        // SAFETY: `block` is exactly `BLOCK_BYTES` long and `output` exactly
        // `BLOCK_SIZE` long, which covers every raw read and write the kernel
        // performs. AVX-512F availability is assumed (see the note on the impl).
        unsafe { dequant_block_avx512(block, output) };
        Ok(())
    }

    /// Computes `output[row] = dot(dequantized row of quant_matrix, input)` for
    /// every row of the matrix.
    ///
    /// The row count is `shape[0]`; the column count is `shape[1]` when
    /// present, otherwise `n_elements() / shape[0]`. Each row is expected to
    /// occupy `ceil(n_cols / BLOCK_SIZE) * BLOCK_BYTES` consecutive bytes of
    /// `quant_matrix.data`.
    ///
    /// # Errors
    ///
    /// * [`QuantError::DimensionMismatch`] when the tensor has no dimensions,
    ///   when `input` has fewer than `n_cols` elements, or when `output` has
    ///   fewer than `n_rows` elements.
    /// * [`QuantError::BufferTooSmall`] when `quant_matrix.data` is shorter
    ///   than `n_rows` packed rows.
    fn gemv(
        &self,
        quant_matrix: &QuantTensor,
        input: &[f32],
        output: &mut [f32],
    ) -> QuantResult<()> {
        // A shapeless tensor used to panic on `shape[0]`; report it instead
        // (at least one dimension expected, none found).
        let Some(&n_rows) = quant_matrix.shape.first() else {
            return Err(QuantError::DimensionMismatch {
                expected: 1,
                got: 0,
            });
        };
        let n_cols = match quant_matrix.shape.get(1) {
            Some(&cols) => cols,
            // 1-D tensor with zero rows: nothing to divide, nothing to compute.
            None if n_rows == 0 => 0,
            None => quant_matrix.n_elements() / n_rows,
        };
        if input.len() < n_cols {
            return Err(QuantError::DimensionMismatch {
                expected: n_cols,
                got: input.len(),
            });
        }
        if output.len() < n_rows {
            return Err(QuantError::DimensionMismatch {
                expected: n_rows,
                got: output.len(),
            });
        }

        let blocks_per_row = n_cols.div_ceil(BLOCK_SIZE);
        // Saturating arithmetic turns an overflowing size into "needs more than
        // can exist", which the length check below rejects.
        let row_bytes = blocks_per_row.saturating_mul(BLOCK_BYTES);
        let needed = n_rows.saturating_mul(row_bytes);
        let available = quant_matrix.data.len();
        if available < needed {
            return Err(QuantError::BufferTooSmall { needed, available });
        }

        for (row, out) in output.iter_mut().enumerate().take(n_rows) {
            let row_start = row * row_bytes;
            let row_data = &quant_matrix.data[row_start..row_start + row_bytes];
            // SAFETY: `input` holds at least `n_cols` elements (checked above),
            // which is what the kernel's unchecked vector loads rely on, and
            // `row_data` holds exactly `blocks_per_row` packed blocks.
            // AVX-512F availability is assumed (see the note on the impl).
            *out = unsafe { gemv_row_avx512(row_data, input, blocks_per_row, n_cols) };
        }
        Ok(())
    }

    /// Runs [`Self::gemv`] once per input row.
    ///
    /// `input` is read as `m` consecutive rows of `k` elements and `output` is
    /// written as `m` consecutive rows of `n` elements; row `i` of `output`
    /// receives the matrix-vector product for row `i` of `input`.
    ///
    /// # Errors
    ///
    /// Returns [`QuantError::DimensionMismatch`] when `input` has fewer than
    /// `m * k` elements or `output` fewer than `m * n` elements, and forwards
    /// any error produced by `gemv` for an individual row.
    fn gemm(
        &self,
        quant_matrix: &QuantTensor,
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> QuantResult<()> {
        // Validate up front so short buffers yield an error rather than a
        // slice-index panic part-way through the batch.
        let input_needed = m.saturating_mul(k);
        if input.len() < input_needed {
            return Err(QuantError::DimensionMismatch {
                expected: input_needed,
                got: input.len(),
            });
        }
        let output_needed = m.saturating_mul(n);
        if output.len() < output_needed {
            return Err(QuantError::DimensionMismatch {
                expected: output_needed,
                got: output.len(),
            });
        }
        for row in 0..m {
            let input_row = &input[row * k..(row + 1) * k];
            let output_row = &mut output[row * n..(row + 1) * n];
            self.gemv(quant_matrix, input_row, output_row)?;
        }
        Ok(())
    }

    /// Number of weights per quantized block.
    fn block_size(&self) -> usize {
        BLOCK_SIZE
    }

    /// Number of bytes per quantized block.
    fn block_bytes(&self) -> usize {
        BLOCK_BYTES
    }

    /// Name of the quantization format handled by this kernel.
    fn name(&self) -> &'static str {
        "Q6_K"
    }
}
#[target_feature(enable = "avx512f")]
unsafe fn dequant_block_avx512(block: &[u8], output: &mut [f32]) {
let ql = &block[0..128];
let qh = &block[128..192];
let scales = &block[192..208];
let d = f16_to_f32(&block[208..]);
let off32 = _mm512_set1_epi32(32);
for group in 0..2usize {
let ql_off = group * 64;
let qh_off = group * 32;
let sc_off = group * 8;
let out_off = group * 128;
let (q1_a, q2_a, q3_a, q4_a) =
decode_q6_group(ql.as_ptr().add(ql_off), qh.as_ptr().add(qh_off));
let (q1_b, q2_b, q3_b, q4_b) =
decode_q6_group(ql.as_ptr().add(ql_off + 16), qh.as_ptr().add(qh_off + 16));
let s0a = d * (*scales.get_unchecked(sc_off)) as i8 as f32;
let s1a = d * (*scales.get_unchecked(sc_off + 2)) as i8 as f32;
let s2a = d * (*scales.get_unchecked(sc_off + 4)) as i8 as f32;
let s3a = d * (*scales.get_unchecked(sc_off + 6)) as i8 as f32;
let s0b = d * (*scales.get_unchecked(sc_off + 1)) as i8 as f32;
let s1b = d * (*scales.get_unchecked(sc_off + 3)) as i8 as f32;
let s2b = d * (*scales.get_unchecked(sc_off + 5)) as i8 as f32;
let s3b = d * (*scales.get_unchecked(sc_off + 7)) as i8 as f32;
let q1a_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q1_a), off32);
let q1a_f32 = _mm512_cvtepi32_ps(q1a_i32);
let vs0a = _mm512_set1_ps(s0a);
let w_q1a = _mm512_mul_ps(vs0a, q1a_f32);
let q1b_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q1_b), off32);
let q1b_f32 = _mm512_cvtepi32_ps(q1b_i32);
let vs0b = _mm512_set1_ps(s0b);
let w_q1b = _mm512_mul_ps(vs0b, q1b_f32);
let ptr_q1 = output.as_mut_ptr().add(out_off);
_mm512_storeu_ps(ptr_q1, w_q1a);
_mm512_storeu_ps(ptr_q1.add(16), w_q1b);
let q2a_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q2_a), off32);
let q2a_f32 = _mm512_cvtepi32_ps(q2a_i32);
let vs1a = _mm512_set1_ps(s1a);
let w_q2a = _mm512_mul_ps(vs1a, q2a_f32);
let q2b_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q2_b), off32);
let q2b_f32 = _mm512_cvtepi32_ps(q2b_i32);
let vs1b = _mm512_set1_ps(s1b);
let w_q2b = _mm512_mul_ps(vs1b, q2b_f32);
let ptr_q2 = output.as_mut_ptr().add(out_off + 32);
_mm512_storeu_ps(ptr_q2, w_q2a);
_mm512_storeu_ps(ptr_q2.add(16), w_q2b);
let q3a_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q3_a), off32);
let q3a_f32 = _mm512_cvtepi32_ps(q3a_i32);
let vs2a = _mm512_set1_ps(s2a);
let w_q3a = _mm512_mul_ps(vs2a, q3a_f32);
let q3b_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q3_b), off32);
let q3b_f32 = _mm512_cvtepi32_ps(q3b_i32);
let vs2b = _mm512_set1_ps(s2b);
let w_q3b = _mm512_mul_ps(vs2b, q3b_f32);
let ptr_q3 = output.as_mut_ptr().add(out_off + 64);
_mm512_storeu_ps(ptr_q3, w_q3a);
_mm512_storeu_ps(ptr_q3.add(16), w_q3b);
let q4a_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q4_a), off32);
let q4a_f32 = _mm512_cvtepi32_ps(q4a_i32);
let vs3a = _mm512_set1_ps(s3a);
let w_q4a = _mm512_mul_ps(vs3a, q4a_f32);
let q4b_i32 = _mm512_sub_epi32(_mm512_cvtepu8_epi32(q4_b), off32);
let q4b_f32 = _mm512_cvtepi32_ps(q4b_i32);
let vs3b = _mm512_set1_ps(s3b);
let w_q4b = _mm512_mul_ps(vs3b, q4b_f32);
let ptr_q4 = output.as_mut_ptr().add(out_off + 96);
_mm512_storeu_ps(ptr_q4, w_q4a);
_mm512_storeu_ps(ptr_q4.add(16), w_q4b);
}
}
#[target_feature(enable = "avx512f")]
unsafe fn gemv_row_avx512(
row_data: &[u8],
input: &[f32],
blocks_per_row: usize,
n_cols: usize,
) -> f32 {
let mut row_sum = 0.0f32;
for blk in 0..blocks_per_row {
let block_offset = blk * BLOCK_BYTES;
let block = &row_data[block_offset..block_offset + BLOCK_BYTES];
let input_offset = blk * BLOCK_SIZE;
let remaining = n_cols.saturating_sub(input_offset);
let ql = &block[0..128];
let qh = &block[128..192];
let scales = &block[192..208];
let d = f16_to_f32(&block[208..]);
if remaining >= BLOCK_SIZE {
let mut block_acc = _mm512_setzero_ps();
let off32 = _mm512_set1_epi32(32);
for group in 0..2usize {
let ql_off = group * 64;
let qh_off = group * 32;
let sc_off = group * 8;
let w_off = input_offset + group * 128;
let s0a = d * (*scales.get_unchecked(sc_off)) as i8 as f32;
let s1a = d * (*scales.get_unchecked(sc_off + 2)) as i8 as f32;
let s2a = d * (*scales.get_unchecked(sc_off + 4)) as i8 as f32;
let s3a = d * (*scales.get_unchecked(sc_off + 6)) as i8 as f32;
let s0b = d * (*scales.get_unchecked(sc_off + 1)) as i8 as f32;
let s1b = d * (*scales.get_unchecked(sc_off + 3)) as i8 as f32;
let s2b = d * (*scales.get_unchecked(sc_off + 5)) as i8 as f32;
let s3b = d * (*scales.get_unchecked(sc_off + 7)) as i8 as f32;
let vs0a = _mm512_set1_ps(s0a);
let vs1a = _mm512_set1_ps(s1a);
let vs2a = _mm512_set1_ps(s2a);
let vs3a = _mm512_set1_ps(s3a);
let vs0b = _mm512_set1_ps(s0b);
let vs1b = _mm512_set1_ps(s1b);
let vs2b = _mm512_set1_ps(s2b);
let vs3b = _mm512_set1_ps(s3b);
let (q1_a, q2_a, q3_a, q4_a) =
decode_q6_group(ql.as_ptr().add(ql_off), qh.as_ptr().add(qh_off));
let (q1_b, q2_b, q3_b, q4_b) =
decode_q6_group(ql.as_ptr().add(ql_off + 16), qh.as_ptr().add(qh_off + 16));
let inp_q1 = input.as_ptr().add(w_off);
let q1a_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q1_a), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs0a, q1a_f),
_mm512_loadu_ps(inp_q1),
block_acc,
);
let q1b_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q1_b), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs0b, q1b_f),
_mm512_loadu_ps(inp_q1.add(16)),
block_acc,
);
let inp_q2 = input.as_ptr().add(w_off + 32);
let q2a_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q2_a), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs1a, q2a_f),
_mm512_loadu_ps(inp_q2),
block_acc,
);
let q2b_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q2_b), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs1b, q2b_f),
_mm512_loadu_ps(inp_q2.add(16)),
block_acc,
);
let inp_q3 = input.as_ptr().add(w_off + 64);
let q3a_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q3_a), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs2a, q3a_f),
_mm512_loadu_ps(inp_q3),
block_acc,
);
let q3b_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q3_b), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs2b, q3b_f),
_mm512_loadu_ps(inp_q3.add(16)),
block_acc,
);
let inp_q4 = input.as_ptr().add(w_off + 96);
let q4a_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q4_a), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs3a, q4a_f),
_mm512_loadu_ps(inp_q4),
block_acc,
);
let q4b_f = _mm512_cvtepi32_ps(_mm512_sub_epi32(_mm512_cvtepu8_epi32(q4_b), off32));
block_acc = _mm512_fmadd_ps(
_mm512_mul_ps(vs3b, q4b_f),
_mm512_loadu_ps(inp_q4.add(16)),
block_acc,
);
}
row_sum += hsum_f32_avx512(block_acc);
} else if remaining > 0 {
let mut partial_sum = 0.0f32;
for group in 0..2usize {
let ql_off = group * 64;
let qh_off = group * 32;
let sc_off = group * 8;
let in_off = input_offset + group * 128;
for l in 0..32 {
let is = l / 16;
let ql_l = *ql.get_unchecked(ql_off + l);
let ql_l32 = *ql.get_unchecked(ql_off + l + 32);
let qh_l = *qh.get_unchecked(qh_off + l);
let q1 = ((ql_l & 0x0F) | ((qh_l & 3) << 4)) as i32 - 32;
let q2 = ((ql_l32 & 0x0F) | (((qh_l >> 2) & 3) << 4)) as i32 - 32;
let q3 = ((ql_l >> 4) | (((qh_l >> 4) & 3) << 4)) as i32 - 32;
let q4 = ((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) as i32 - 32;
let s0 = d * (*scales.get_unchecked(sc_off + is)) as i8 as f32;
let s1 = d * (*scales.get_unchecked(sc_off + is + 2)) as i8 as f32;
let s2 = d * (*scales.get_unchecked(sc_off + is + 4)) as i8 as f32;
let s3 = d * (*scales.get_unchecked(sc_off + is + 6)) as i8 as f32;
let idx0 = in_off + l;
let idx1 = in_off + l + 32;
let idx2 = in_off + l + 64;
let idx3 = in_off + l + 96;
if idx0 < n_cols {
partial_sum += s0 * q1 as f32 * input[idx0];
}
if idx1 < n_cols {
partial_sum += s1 * q2 as f32 * input[idx1];
}
if idx2 < n_cols {
partial_sum += s2 * q3 as f32 * input[idx2];
}
if idx3 < n_cols {
partial_sum += s3 * q4 as f32 * input[idx3];
}
}
}
row_sum += partial_sum;
}
}
row_sum
}
#[cfg(all(test, target_arch = "x86_64", feature = "simd-avx512"))]
mod tests {
    use super::*;
    use crate::reference::q6_k::Q6KRef;

    /// Assembles one raw Q6_K block: `ql`, `qh`, `scales`, then `d` as
    /// little-endian f16 bits.
    fn make_q6k_block(d: f32, ql: &[u8; 128], qh: &[u8; 64], scales: &[u8; 16]) -> Vec<u8> {
        let mut block = Vec::with_capacity(BLOCK_BYTES);
        block.extend_from_slice(ql);
        block.extend_from_slice(qh);
        block.extend_from_slice(scales);
        block.extend_from_slice(&half::f16::from_f32(d).to_bits().to_le_bytes());
        block
    }

    /// Wraps a single block as a one-row Q6K tensor with `n_cols` columns.
    fn make_tensor(block: Vec<u8>, n_cols: usize) -> crate::types::QuantTensor {
        crate::types::QuantTensor::new(block, vec![1, n_cols], oxillama_gguf::GgufTensorType::Q6K)
    }

    /// Runtime guard: the kernels must not run on CPUs without AVX-512F.
    fn avx512f_available() -> bool {
        std::arch::is_x86_feature_detected!("avx512f")
    }

    /// Block with non-trivial `ql`, `qh` and mixed-sign scales, `d = 0.5`.
    fn make_varied_block() -> Vec<u8> {
        let mut ql = [0u8; 128];
        for (i, b) in ql.iter_mut().enumerate() {
            *b = ((i * 7 + 3) & 0xFF) as u8;
        }
        let mut qh = [0u8; 64];
        for (i, b) in qh.iter_mut().enumerate() {
            *b = ((i * 13 + 5) & 0xFF) as u8;
        }
        let mut scales = [0u8; 16];
        for (i, s) in scales.iter_mut().enumerate() {
            *s = (i as i8 * 3 - 8) as u8;
        }
        make_q6k_block(0.5, &ql, &qh, &scales)
    }

    /// Dequantizes `block` with both kernels and requires every element to
    /// agree within `tolerance`.
    fn assert_dequant_matches_reference(block: &[u8], tolerance: f32, label: &str) {
        let mut out_avx512 = vec![0.0f32; BLOCK_SIZE];
        let mut out_ref = vec![0.0f32; BLOCK_SIZE];
        Q6_KAvx512
            .dequant_block(block, &mut out_avx512)
            .expect("avx512 dequant");
        Q6KRef
            .dequant_block(block, &mut out_ref)
            .expect("ref dequant");
        for (i, (&a, &r)) in out_avx512.iter().zip(out_ref.iter()).enumerate() {
            assert!(
                (a - r).abs() < tolerance,
                "dequant mismatch [{label}] at index {i}: avx512={a}, ref={r}"
            );
        }
    }

    #[test]
    #[cfg_attr(not(target_feature = "avx512f"), ignore)]
    fn test_q6k_avx512_dequant_matches_reference_zero() {
        if !avx512f_available() {
            return;
        }
        let block = make_q6k_block(0.0, &[0; 128], &[0; 64], &[0; 16]);
        assert_dequant_matches_reference(&block, 1e-5, "zero");
    }

    #[test]
    #[cfg_attr(not(target_feature = "avx512f"), ignore)]
    fn test_q6k_avx512_dequant_matches_reference_quant32() {
        if !avx512f_available() {
            return;
        }
        // qh = 0xAA puts 0b10 in every 2-bit field, so every quant is 32.
        let block = make_q6k_block(1.0, &[0; 128], &[0xAAu8; 64], &[1; 16]);
        assert_dequant_matches_reference(&block, 1e-4, "quant32");
    }

    #[test]
    #[cfg_attr(not(target_feature = "avx512f"), ignore)]
    fn test_q6k_avx512_dequant_matches_reference_varied() {
        if !avx512f_available() {
            return;
        }
        assert_dequant_matches_reference(&make_varied_block(), 1e-3, "varied");
    }

    #[test]
    #[cfg_attr(not(target_feature = "avx512f"), ignore)]
    fn test_q6k_avx512_gemv_matches_reference() {
        if !avx512f_available() {
            return;
        }
        let block = make_varied_block();
        let tensor_avx512 = make_tensor(block.clone(), 256);
        let tensor_ref = make_tensor(block, 256);
        let input: Vec<f32> = (0..256).map(|i| (i as f32 * 0.01) - 1.28).collect();
        let mut out_avx512 = vec![0.0f32; 1];
        let mut out_ref = vec![0.0f32; 1];
        Q6_KAvx512
            .gemv(&tensor_avx512, &input, &mut out_avx512)
            .expect("avx512 gemv");
        Q6KRef
            .gemv(&tensor_ref, &input, &mut out_ref)
            .expect("ref gemv");
        assert!(
            (out_avx512[0] - out_ref[0]).abs() < 1e-2,
            "gemv mismatch: avx512={}, ref={}",
            out_avx512[0],
            out_ref[0]
        );
    }

    #[test]
    #[cfg_attr(not(target_feature = "avx512f"), ignore)]
    fn test_q6k_avx512_gemv_partial_block() {
        if !avx512f_available() {
            return;
        }
        let scales = [1i8 as u8; 16];
        let block = make_q6k_block(1.0, &[0; 128], &[0xAAu8; 64], &scales);
        // 200 columns: the single block is only partly covered.
        let tensor_avx512 = make_tensor(block.clone(), 200);
        let tensor_ref = make_tensor(block, 200);
        let input = vec![1.0f32; 200];
        let mut out_avx512 = vec![0.0f32; 1];
        let mut out_ref = vec![0.0f32; 1];
        Q6_KAvx512
            .gemv(&tensor_avx512, &input, &mut out_avx512)
            .expect("avx512 gemv partial");
        Q6KRef
            .gemv(&tensor_ref, &input, &mut out_ref)
            .expect("ref gemv partial");
        assert!(
            (out_avx512[0] - out_ref[0]).abs() < 1e-2,
            "partial gemv mismatch: avx512={}, ref={}",
            out_avx512[0],
            out_ref[0]
        );
    }

    // The size checks return before any AVX-512 kernel is reached, so this
    // test needs neither the `ignore` attribute nor the runtime guard.
    #[test]
    fn test_q6k_avx512_rejects_short_buffers() {
        let block = make_q6k_block(1.0, &[0; 128], &[0; 64], &[1; 16]);
        let mut out = vec![0.0f32; BLOCK_SIZE];
        assert!(matches!(
            Q6_KAvx512.dequant_block(&block[..BLOCK_BYTES - 1], &mut out),
            Err(QuantError::BufferTooSmall { .. })
        ));
        assert!(matches!(
            Q6_KAvx512.dequant_block(&block, &mut out[..BLOCK_SIZE - 1]),
            Err(QuantError::BufferTooSmall { .. })
        ));

        let tensor = make_tensor(block, BLOCK_SIZE);
        let short_input = vec![0.0f32; BLOCK_SIZE - 1];
        let mut row_out = vec![0.0f32; 1];
        assert!(matches!(
            Q6_KAvx512.gemv(&tensor, &short_input, &mut row_out),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }
}