#![cfg(all(feature = "simd-avx2", target_arch = "x86_64"))]
use core::arch::x86_64::*;
use crate::error::{QuantError, QuantResult};
use crate::simd::avx2::util::{f16_to_f32, hsum_f32_avx};
use crate::traits::QuantKernel;
use crate::types::QuantTensor;
pub const BLOCK_SIZE: usize = 256;
pub const BLOCK_BYTES: usize = 210;
#[allow(non_camel_case_types)]
pub struct Q6_KAvx2;
impl QuantKernel for Q6_KAvx2 {
fn dequant_block(&self, block: &[u8], output: &mut [f32]) -> QuantResult<()> {
if block.len() < BLOCK_BYTES {
return Err(QuantError::BufferTooSmall {
needed: BLOCK_BYTES,
available: block.len(),
});
}
if output.len() < BLOCK_SIZE {
return Err(QuantError::BufferTooSmall {
needed: BLOCK_SIZE,
available: output.len(),
});
}
unsafe { dequant_block_avx2(block, output) }
Ok(())
}
fn gemv(
&self,
quant_matrix: &QuantTensor,
input: &[f32],
output: &mut [f32],
) -> QuantResult<()> {
let n_rows = quant_matrix.shape[0];
let n_cols = if quant_matrix.shape.len() > 1 {
quant_matrix.shape[1]
} else {
quant_matrix.n_elements() / n_rows
};
if input.len() < n_cols {
return Err(QuantError::DimensionMismatch {
expected: n_cols,
got: input.len(),
});
}
if output.len() < n_rows {
return Err(QuantError::DimensionMismatch {
expected: n_rows,
got: output.len(),
});
}
let blocks_per_row = n_cols.div_ceil(BLOCK_SIZE);
let row_bytes = blocks_per_row * BLOCK_BYTES;
for (row, out) in output.iter_mut().enumerate().take(n_rows) {
let row_start = row * row_bytes;
*out = unsafe {
gemv_row_avx2(
&quant_matrix.data[row_start..row_start + row_bytes],
input,
blocks_per_row,
n_cols,
)
};
}
Ok(())
}
fn gemm(
&self,
quant_matrix: &QuantTensor,
input: &[f32],
output: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> QuantResult<()> {
for row in 0..m {
let input_row = &input[row * k..(row + 1) * k];
let output_row = &mut output[row * n..(row + 1) * n];
self.gemv(quant_matrix, input_row, output_row)?;
}
Ok(())
}
fn block_size(&self) -> usize {
BLOCK_SIZE
}
fn block_bytes(&self) -> usize {
BLOCK_BYTES
}
fn name(&self) -> &'static str {
"Q6_K"
}
fn matvec_q8_fused(
&self,
weights: &[u8],
acts_q8: &[u8],
out: &mut [f32],
n_rows: usize,
n_cols: usize,
) -> QuantResult<()> {
if out.len() < n_rows {
return Err(QuantError::DimensionMismatch {
expected: n_rows,
got: out.len(),
});
}
let blocks_per_row = n_cols.div_ceil(BLOCK_SIZE);
let row_bytes = blocks_per_row * BLOCK_BYTES;
let q8_blocks_per_row = blocks_per_row * 8;
let acts_needed = q8_blocks_per_row * Q8_0_BLOCK_BYTES;
if weights.len() < n_rows * row_bytes {
return Err(QuantError::BufferTooSmall {
needed: n_rows * row_bytes,
available: weights.len(),
});
}
if acts_q8.len() < acts_needed {
return Err(QuantError::BufferTooSmall {
needed: acts_needed,
available: acts_q8.len(),
});
}
for row in 0..n_rows {
let row_start = row * row_bytes;
let row_sum = unsafe {
fused_q6_k_q8_0_row_avx2(
&weights[row_start..row_start + row_bytes],
acts_q8,
blocks_per_row,
n_cols,
)
};
out[row] += row_sum;
}
Ok(())
}
}
const Q8_0_BLOCK_BYTES: usize = 34;
#[target_feature(enable = "avx2,fma")]
unsafe fn fused_q6_k_q8_0_row_avx2(
row_data: &[u8],
acts_q8: &[u8],
blocks_per_row: usize,
n_cols: usize,
) -> f32 {
let mut row_sum = 0.0f32;
for blk in 0..blocks_per_row {
let bo = blk * BLOCK_BYTES;
let block = &row_data[bo..bo + BLOCK_BYTES];
let ql = &block[0..128];
let qh = &block[128..192];
let scales = &block[192..208];
let d = f16_to_f32(&block[208..]);
let input_offset = blk * BLOCK_SIZE;
let cols_in_block = (n_cols - input_offset).min(BLOCK_SIZE);
for group in 0..2 {
let ql_off = group * 64;
let qh_off = group * 32;
let sc_off = group * 8;
let in_off = group * 128;
for l in 0..32 {
let is = l / 16;
let q1 = ((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4)) as i32 - 32;
let q2 =
((ql[ql_off + l + 32] & 0x0F) | (((qh[qh_off + l] >> 2) & 3) << 4)) as i32 - 32;
let q3 = ((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4)) as i32 - 32;
let q4 =
((ql[ql_off + l + 32] >> 4) | (((qh[qh_off + l] >> 6) & 3) << 4)) as i32 - 32;
let s0 = d * scales[sc_off + is] as i8 as f32;
let s1 = d * scales[sc_off + is + 2] as i8 as f32;
let s2 = d * scales[sc_off + is + 4] as i8 as f32;
let s3 = d * scales[sc_off + is + 6] as i8 as f32;
let c0 = in_off + l;
let c1 = in_off + l + 32;
let c2 = in_off + l + 64;
let c3 = in_off + l + 96;
let sample_q8 = |col: usize| -> Option<f32> {
if col >= cols_in_block {
return None;
}
let q8_blk = blk * 8 + col / 32;
let q8_lane = col % 32;
let ab = &acts_q8[q8_blk * Q8_0_BLOCK_BYTES..(q8_blk + 1) * Q8_0_BLOCK_BYTES];
let d_a = f16_to_f32(ab);
let q_a = ab[2 + q8_lane] as i8 as f32;
Some(d_a * q_a)
};
if let Some(a0) = sample_q8(c0) {
row_sum += s0 * q1 as f32 * a0;
}
if let Some(a1) = sample_q8(c1) {
row_sum += s1 * q2 as f32 * a1;
}
if let Some(a2) = sample_q8(c2) {
row_sum += s2 * q3 as f32 * a2;
}
if let Some(a3) = sample_q8(c3) {
row_sum += s3 * q4 as f32 * a3;
}
}
}
}
row_sum
}
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn decode_q6_group(
ql_ptr: *const u8,
qh_ptr: *const u8,
) -> (__m128i, __m128i, __m128i, __m128i) {
let ql0 = _mm_loadu_si128(ql_ptr as *const __m128i); let ql1 = _mm_loadu_si128(ql_ptr.add(32) as *const __m128i);
let qh_raw = _mm_loadu_si128(qh_ptr as *const __m128i);
let mask4 = _mm_set1_epi8(0x0F_u8 as i8);
let mask2 = _mm_set1_epi8(0x03_u8 as i8);
let ql0_lo = _mm_and_si128(ql0, mask4); let ql1_lo = _mm_and_si128(ql1, mask4);
let ql0_hi = _mm_and_si128(_mm_srli_epi16(ql0, 4), mask4); let ql1_hi = _mm_and_si128(_mm_srli_epi16(ql1, 4), mask4);
let qh_sh0 = _mm_and_si128(qh_raw, mask2); let qh_sh2 = _mm_and_si128(_mm_srli_epi16(qh_raw, 2), mask2); let qh_sh4 = _mm_and_si128(_mm_srli_epi16(qh_raw, 4), mask2); let qh_sh6 = _mm_and_si128(_mm_srli_epi16(qh_raw, 6), mask2);
let qh_hi0 = _mm_slli_epi16(qh_sh0, 4); let qh_hi2 = _mm_slli_epi16(qh_sh2, 4);
let qh_hi4 = _mm_slli_epi16(qh_sh4, 4);
let qh_hi6 = _mm_slli_epi16(qh_sh6, 4);
let q1 = _mm_or_si128(ql0_lo, qh_hi0);
let q2 = _mm_or_si128(ql1_lo, qh_hi2);
let q3 = _mm_or_si128(ql0_hi, qh_hi4);
let q4 = _mm_or_si128(ql1_hi, qh_hi6);
(q1, q2, q3, q4)
}
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn q6_bytes_to_f32_pair(q: __m128i) -> (__m256, __m256) {
let offset = _mm256_set1_epi32(32);
let q_lo = _mm256_cvtepu8_epi32(q);
let f_lo = _mm256_cvtepi32_ps(_mm256_sub_epi32(q_lo, offset));
let q_hi_bytes = _mm_srli_si128(q, 8);
let q_hi = _mm256_cvtepu8_epi32(q_hi_bytes);
let f_hi = _mm256_cvtepi32_ps(_mm256_sub_epi32(q_hi, offset));
(f_lo, f_hi)
}
#[target_feature(enable = "avx2,fma")]
unsafe fn dequant_block_avx2(block: &[u8], output: &mut [f32]) {
let ql = &block[0..128];
let qh = &block[128..192];
let scales = &block[192..208];
let d = f16_to_f32(&block[208..]);
for group in 0..2usize {
let ql_off = group * 64; let qh_off = group * 32; let sc_off = group * 8; let out_off = group * 128;
let (q1_a, q2_a, q3_a, q4_a) =
decode_q6_group(ql.as_ptr().add(ql_off), qh.as_ptr().add(qh_off));
let (q1_b, q2_b, q3_b, q4_b) =
decode_q6_group(ql.as_ptr().add(ql_off + 16), qh.as_ptr().add(qh_off + 16));
let s0a = d * (*scales.get_unchecked(sc_off)) as i8 as f32;
let s1a = d * (*scales.get_unchecked(sc_off + 2)) as i8 as f32;
let s2a = d * (*scales.get_unchecked(sc_off + 4)) as i8 as f32;
let s3a = d * (*scales.get_unchecked(sc_off + 6)) as i8 as f32;
let s0b = d * (*scales.get_unchecked(sc_off + 1)) as i8 as f32;
let s1b = d * (*scales.get_unchecked(sc_off + 3)) as i8 as f32;
let s2b = d * (*scales.get_unchecked(sc_off + 5)) as i8 as f32;
let s3b = d * (*scales.get_unchecked(sc_off + 7)) as i8 as f32;
let vs0a = _mm256_set1_ps(s0a);
let vs1a = _mm256_set1_ps(s1a);
let vs2a = _mm256_set1_ps(s2a);
let vs3a = _mm256_set1_ps(s3a);
let vs0b = _mm256_set1_ps(s0b);
let vs1b = _mm256_set1_ps(s1b);
let vs2b = _mm256_set1_ps(s2b);
let vs3b = _mm256_set1_ps(s3b);
let (q1a_lo, q1a_hi) = q6_bytes_to_f32_pair(q1_a);
let (q1b_lo, q1b_hi) = q6_bytes_to_f32_pair(q1_b);
let ptr_q1 = output.as_mut_ptr().add(out_off);
_mm256_storeu_ps(ptr_q1, _mm256_mul_ps(vs0a, q1a_lo));
_mm256_storeu_ps(ptr_q1.add(8), _mm256_mul_ps(vs0a, q1a_hi));
_mm256_storeu_ps(ptr_q1.add(16), _mm256_mul_ps(vs0b, q1b_lo));
_mm256_storeu_ps(ptr_q1.add(24), _mm256_mul_ps(vs0b, q1b_hi));
let (q2a_lo, q2a_hi) = q6_bytes_to_f32_pair(q2_a);
let (q2b_lo, q2b_hi) = q6_bytes_to_f32_pair(q2_b);
let ptr_q2 = output.as_mut_ptr().add(out_off + 32);
_mm256_storeu_ps(ptr_q2, _mm256_mul_ps(vs1a, q2a_lo));
_mm256_storeu_ps(ptr_q2.add(8), _mm256_mul_ps(vs1a, q2a_hi));
_mm256_storeu_ps(ptr_q2.add(16), _mm256_mul_ps(vs1b, q2b_lo));
_mm256_storeu_ps(ptr_q2.add(24), _mm256_mul_ps(vs1b, q2b_hi));
let (q3a_lo, q3a_hi) = q6_bytes_to_f32_pair(q3_a);
let (q3b_lo, q3b_hi) = q6_bytes_to_f32_pair(q3_b);
let ptr_q3 = output.as_mut_ptr().add(out_off + 64);
_mm256_storeu_ps(ptr_q3, _mm256_mul_ps(vs2a, q3a_lo));
_mm256_storeu_ps(ptr_q3.add(8), _mm256_mul_ps(vs2a, q3a_hi));
_mm256_storeu_ps(ptr_q3.add(16), _mm256_mul_ps(vs2b, q3b_lo));
_mm256_storeu_ps(ptr_q3.add(24), _mm256_mul_ps(vs2b, q3b_hi));
let (q4a_lo, q4a_hi) = q6_bytes_to_f32_pair(q4_a);
let (q4b_lo, q4b_hi) = q6_bytes_to_f32_pair(q4_b);
let ptr_q4 = output.as_mut_ptr().add(out_off + 96);
_mm256_storeu_ps(ptr_q4, _mm256_mul_ps(vs3a, q4a_lo));
_mm256_storeu_ps(ptr_q4.add(8), _mm256_mul_ps(vs3a, q4a_hi));
_mm256_storeu_ps(ptr_q4.add(16), _mm256_mul_ps(vs3b, q4b_lo));
_mm256_storeu_ps(ptr_q4.add(24), _mm256_mul_ps(vs3b, q4b_hi));
}
}
#[target_feature(enable = "avx2,fma")]
unsafe fn gemv_row_avx2(
row_data: &[u8],
input: &[f32],
blocks_per_row: usize,
n_cols: usize,
) -> f32 {
let mut row_sum = 0.0f32;
for blk in 0..blocks_per_row {
let block_offset = blk * BLOCK_BYTES;
let block = &row_data[block_offset..block_offset + BLOCK_BYTES];
let input_offset = blk * BLOCK_SIZE;
let remaining = n_cols.saturating_sub(input_offset);
let ql = &block[0..128];
let qh = &block[128..192];
let scales = &block[192..208];
let d = f16_to_f32(&block[208..]);
if remaining >= BLOCK_SIZE {
let mut block_acc = _mm256_setzero_ps();
for group in 0..2usize {
let ql_off = group * 64;
let qh_off = group * 32;
let sc_off = group * 8;
let w_off = input_offset + group * 128;
let s0a = d * (*scales.get_unchecked(sc_off)) as i8 as f32;
let s1a = d * (*scales.get_unchecked(sc_off + 2)) as i8 as f32;
let s2a = d * (*scales.get_unchecked(sc_off + 4)) as i8 as f32;
let s3a = d * (*scales.get_unchecked(sc_off + 6)) as i8 as f32;
let s0b = d * (*scales.get_unchecked(sc_off + 1)) as i8 as f32;
let s1b = d * (*scales.get_unchecked(sc_off + 3)) as i8 as f32;
let s2b = d * (*scales.get_unchecked(sc_off + 5)) as i8 as f32;
let s3b = d * (*scales.get_unchecked(sc_off + 7)) as i8 as f32;
let vs0a = _mm256_set1_ps(s0a);
let vs1a = _mm256_set1_ps(s1a);
let vs2a = _mm256_set1_ps(s2a);
let vs3a = _mm256_set1_ps(s3a);
let vs0b = _mm256_set1_ps(s0b);
let vs1b = _mm256_set1_ps(s1b);
let vs2b = _mm256_set1_ps(s2b);
let vs3b = _mm256_set1_ps(s3b);
let (q1_a, q2_a, q3_a, q4_a) =
decode_q6_group(ql.as_ptr().add(ql_off), qh.as_ptr().add(qh_off));
let (q1_b, q2_b, q3_b, q4_b) =
decode_q6_group(ql.as_ptr().add(ql_off + 16), qh.as_ptr().add(qh_off + 16));
let off32 = _mm256_set1_epi32(32);
let inp_q1 = input.as_ptr().add(w_off);
let q1a_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q1_a), off32);
let q1a_lo_f = _mm256_cvtepi32_ps(q1a_lo_i);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs0a, q1a_lo_f),
_mm256_loadu_ps(inp_q1),
block_acc,
);
let q1a_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q1_a, 8)), off32);
let q1a_hi_f = _mm256_cvtepi32_ps(q1a_hi_i);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs0a, q1a_hi_f),
_mm256_loadu_ps(inp_q1.add(8)),
block_acc,
);
let q1b_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q1_b), off32);
let q1b_lo_f = _mm256_cvtepi32_ps(q1b_lo_i);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs0b, q1b_lo_f),
_mm256_loadu_ps(inp_q1.add(16)),
block_acc,
);
let q1b_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q1_b, 8)), off32);
let q1b_hi_f = _mm256_cvtepi32_ps(q1b_hi_i);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs0b, q1b_hi_f),
_mm256_loadu_ps(inp_q1.add(24)),
block_acc,
);
let inp_q2 = input.as_ptr().add(w_off + 32);
let q2a_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q2_a), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs1a, _mm256_cvtepi32_ps(q2a_lo_i)),
_mm256_loadu_ps(inp_q2),
block_acc,
);
let q2a_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q2_a, 8)), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs1a, _mm256_cvtepi32_ps(q2a_hi_i)),
_mm256_loadu_ps(inp_q2.add(8)),
block_acc,
);
let q2b_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q2_b), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs1b, _mm256_cvtepi32_ps(q2b_lo_i)),
_mm256_loadu_ps(inp_q2.add(16)),
block_acc,
);
let q2b_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q2_b, 8)), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs1b, _mm256_cvtepi32_ps(q2b_hi_i)),
_mm256_loadu_ps(inp_q2.add(24)),
block_acc,
);
let inp_q3 = input.as_ptr().add(w_off + 64);
let q3a_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q3_a), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs2a, _mm256_cvtepi32_ps(q3a_lo_i)),
_mm256_loadu_ps(inp_q3),
block_acc,
);
let q3a_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q3_a, 8)), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs2a, _mm256_cvtepi32_ps(q3a_hi_i)),
_mm256_loadu_ps(inp_q3.add(8)),
block_acc,
);
let q3b_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q3_b), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs2b, _mm256_cvtepi32_ps(q3b_lo_i)),
_mm256_loadu_ps(inp_q3.add(16)),
block_acc,
);
let q3b_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q3_b, 8)), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs2b, _mm256_cvtepi32_ps(q3b_hi_i)),
_mm256_loadu_ps(inp_q3.add(24)),
block_acc,
);
let inp_q4 = input.as_ptr().add(w_off + 96);
let q4a_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q4_a), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs3a, _mm256_cvtepi32_ps(q4a_lo_i)),
_mm256_loadu_ps(inp_q4),
block_acc,
);
let q4a_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q4_a, 8)), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs3a, _mm256_cvtepi32_ps(q4a_hi_i)),
_mm256_loadu_ps(inp_q4.add(8)),
block_acc,
);
let q4b_lo_i = _mm256_sub_epi32(_mm256_cvtepu8_epi32(q4_b), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs3b, _mm256_cvtepi32_ps(q4b_lo_i)),
_mm256_loadu_ps(inp_q4.add(16)),
block_acc,
);
let q4b_hi_i =
_mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_srli_si128(q4_b, 8)), off32);
block_acc = _mm256_fmadd_ps(
_mm256_mul_ps(vs3b, _mm256_cvtepi32_ps(q4b_hi_i)),
_mm256_loadu_ps(inp_q4.add(24)),
block_acc,
);
}
row_sum += hsum_f32_avx(block_acc);
} else if remaining > 0 {
let mut partial_sum = 0.0f32;
for group in 0..2usize {
let ql_off = group * 64;
let qh_off = group * 32;
let sc_off = group * 8;
let in_off = input_offset + group * 128;
for l in 0..32 {
let is = l / 16;
let ql_l = *ql.get_unchecked(ql_off + l);
let ql_l32 = *ql.get_unchecked(ql_off + l + 32);
let qh_l = *qh.get_unchecked(qh_off + l);
let q1 = ((ql_l & 0x0F) | ((qh_l & 3) << 4)) as i32 - 32;
let q2 = ((ql_l32 & 0x0F) | (((qh_l >> 2) & 3) << 4)) as i32 - 32;
let q3 = ((ql_l >> 4) | (((qh_l >> 4) & 3) << 4)) as i32 - 32;
let q4 = ((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) as i32 - 32;
let s0 = d * (*scales.get_unchecked(sc_off + is)) as i8 as f32;
let s1 = d * (*scales.get_unchecked(sc_off + is + 2)) as i8 as f32;
let s2 = d * (*scales.get_unchecked(sc_off + is + 4)) as i8 as f32;
let s3 = d * (*scales.get_unchecked(sc_off + is + 6)) as i8 as f32;
let idx0 = in_off + l;
let idx1 = in_off + l + 32;
let idx2 = in_off + l + 64;
let idx3 = in_off + l + 96;
if idx0 < n_cols {
partial_sum += s0 * q1 as f32 * input[idx0];
}
if idx1 < n_cols {
partial_sum += s1 * q2 as f32 * input[idx1];
}
if idx2 < n_cols {
partial_sum += s2 * q3 as f32 * input[idx2];
}
if idx3 < n_cols {
partial_sum += s3 * q4 as f32 * input[idx3];
}
}
}
row_sum += partial_sum;
}
}
row_sum
}
#[cfg(all(test, target_arch = "x86_64", feature = "simd-avx2"))]
mod tests {
use super::*;
use crate::reference::q6_k::Q6KRef;
fn make_q6k_block(d: f32, ql: &[u8; 128], qh: &[u8; 64], scales: &[u8; 16]) -> Vec<u8> {
let mut block = Vec::with_capacity(BLOCK_BYTES);
block.extend_from_slice(ql);
block.extend_from_slice(qh);
block.extend_from_slice(scales);
block.extend_from_slice(&half::f16::from_f32(d).to_bits().to_le_bytes());
block
}
fn make_tensor(block: Vec<u8>, n_cols: usize) -> crate::types::QuantTensor {
crate::types::QuantTensor::new(block, vec![1, n_cols], oxillama_gguf::GgufTensorType::Q6K)
}
#[test]
fn test_q6k_avx2_dequant_matches_reference_zero() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let block = make_q6k_block(0.0, &[0; 128], &[0; 64], &[0; 16]);
let mut out_avx2 = vec![0.0f32; 256];
let mut out_ref = vec![0.0f32; 256];
Q6_KAvx2
.dequant_block(&block, &mut out_avx2)
.expect("avx2 dequant");
Q6KRef
.dequant_block(&block, &mut out_ref)
.expect("ref dequant");
for (i, (&a, &r)) in out_avx2.iter().zip(out_ref.iter()).enumerate() {
assert!(
(a - r).abs() < 1e-5,
"dequant mismatch [zero] at index {i}: avx2={a}, ref={r}"
);
}
}
#[test]
fn test_q6k_avx2_dequant_matches_reference_quant32() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let qh = [0xAAu8; 64];
let scales: [u8; 16] = [1; 16];
let block = make_q6k_block(1.0, &[0; 128], &qh, &scales);
let mut out_avx2 = vec![0.0f32; 256];
let mut out_ref = vec![0.0f32; 256];
Q6_KAvx2
.dequant_block(&block, &mut out_avx2)
.expect("avx2 dequant");
Q6KRef
.dequant_block(&block, &mut out_ref)
.expect("ref dequant");
for (i, (&a, &r)) in out_avx2.iter().zip(out_ref.iter()).enumerate() {
assert!(
(a - r).abs() < 1e-4,
"dequant mismatch [quant32] at index {i}: avx2={a}, ref={r}"
);
}
}
#[test]
fn test_q6k_avx2_dequant_matches_reference_varied() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let mut ql = [0u8; 128];
for (i, b) in ql.iter_mut().enumerate() {
*b = ((i * 7 + 3) & 0xFF) as u8;
}
let mut qh = [0u8; 64];
for (i, b) in qh.iter_mut().enumerate() {
*b = ((i * 13 + 5) & 0xFF) as u8;
}
let mut scales = [0u8; 16];
for (i, s) in scales.iter_mut().enumerate() {
*s = (i as i8 * 3 - 8) as u8;
}
let block = make_q6k_block(0.5, &ql, &qh, &scales);
let mut out_avx2 = vec![0.0f32; 256];
let mut out_ref = vec![0.0f32; 256];
Q6_KAvx2
.dequant_block(&block, &mut out_avx2)
.expect("avx2 dequant");
Q6KRef
.dequant_block(&block, &mut out_ref)
.expect("ref dequant");
for (i, (&a, &r)) in out_avx2.iter().zip(out_ref.iter()).enumerate() {
assert!(
(a - r).abs() < 1e-4,
"dequant mismatch [varied] at index {i}: avx2={a}, ref={r}"
);
}
}
#[test]
fn test_q6k_avx2_gemv_matches_reference() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let mut ql = [0u8; 128];
for (i, b) in ql.iter_mut().enumerate() {
*b = ((i * 7 + 3) & 0xFF) as u8;
}
let mut qh = [0u8; 64];
for (i, b) in qh.iter_mut().enumerate() {
*b = ((i * 13 + 5) & 0xFF) as u8;
}
let mut scales = [0u8; 16];
for (i, s) in scales.iter_mut().enumerate() {
*s = (i as i8 * 3 - 8) as u8;
}
let block = make_q6k_block(0.5, &ql, &qh, &scales);
let tensor_avx2 = make_tensor(block.clone(), 256);
let tensor_ref = make_tensor(block, 256);
let input: Vec<f32> = (0..256).map(|i| (i as f32 * 0.01) - 1.28).collect();
let mut out_avx2 = vec![0.0f32; 1];
let mut out_ref = vec![0.0f32; 1];
Q6_KAvx2
.gemv(&tensor_avx2, &input, &mut out_avx2)
.expect("avx2 gemv");
Q6KRef
.gemv(&tensor_ref, &input, &mut out_ref)
.expect("ref gemv");
assert!(
(out_avx2[0] - out_ref[0]).abs() < 1e-2,
"gemv mismatch: avx2={}, ref={}",
out_avx2[0],
out_ref[0]
);
}
#[test]
fn test_q6k_avx2_gemv_partial_block() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let scales = [1i8 as u8; 16];
let block = make_q6k_block(1.0, &[0; 128], &[0xAAu8; 64], &scales);
let tensor_avx2 = make_tensor(block.clone(), 200);
let tensor_ref = make_tensor(block, 200);
let input = vec![1.0f32; 200];
let mut out_avx2 = vec![0.0f32; 1];
let mut out_ref = vec![0.0f32; 1];
Q6_KAvx2
.gemv(&tensor_avx2, &input, &mut out_avx2)
.expect("avx2 gemv partial");
Q6KRef
.gemv(&tensor_ref, &input, &mut out_ref)
.expect("ref gemv partial");
assert!(
(out_avx2[0] - out_ref[0]).abs() < 1e-2,
"partial gemv mismatch: avx2={}, ref={}",
out_avx2[0],
out_ref[0]
);
}
fn make_q8_0_block(scale: f32, values: &[i8; 32]) -> Vec<u8> {
let mut block = Vec::with_capacity(34);
block.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
for &v in values {
block.push(v as u8);
}
block
}
fn make_q8_acts(n_q8_blocks: usize, scale: f32, values: &[i8; 32]) -> Vec<u8> {
make_q8_0_block(scale, values).repeat(n_q8_blocks)
}
#[test]
fn test_q6k_avx2_fused_matches_reference_single_block() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let mut ql = [0u8; 128];
for (i, q) in ql.iter_mut().enumerate() {
*q = ((i * 5 + 11) & 0xFF) as u8;
}
let mut qh = [0u8; 64];
for (i, h) in qh.iter_mut().enumerate() {
*h = ((i * 13 + 7) & 0xFF) as u8;
}
let scales: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let w_block = make_q6k_block(0.01, &ql, &qh, &scales);
let act_vals: [i8; 32] = [
1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 0, 1, -1, 2, -2, 3, -3,
4, -4, 5, -5, 6, -6, 7, -7, 8,
];
let acts = make_q8_acts(8, 0.1, &act_vals);
let mut out_avx2 = vec![0.0f32; 1];
let mut out_ref = vec![0.0f32; 1];
Q6_KAvx2
.matvec_q8_fused(&w_block, &acts, &mut out_avx2, 1, 256)
.expect("avx2 fused single block");
Q6KRef
.matvec_q8_fused(&w_block, &acts, &mut out_ref, 1, 256)
.expect("ref fused single block");
let err = (out_avx2[0] - out_ref[0]).abs();
assert!(
err < 1e-3,
"fused single-block mismatch: avx2={} ref={} err={}",
out_avx2[0],
out_ref[0],
err
);
}
#[test]
fn test_q6k_avx2_fused_multi_row() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let n_rows = 3usize;
let n_cols = 512usize;
let blocks_per_row = 2usize;
let q8_blocks_per_row = blocks_per_row * 8;
let mut all_weights = Vec::new();
for r in 0..n_rows {
for b in 0..blocks_per_row {
let mut ql = [0u8; 128];
for (i, q) in ql.iter_mut().enumerate() {
*q = ((r * 17 + b * 11 + i * 5) & 0xFF) as u8;
}
let mut qh = [0u8; 64];
for (i, h) in qh.iter_mut().enumerate() {
*h = ((r * 13 + b * 19 + i * 3) & 0xFF) as u8;
}
let scales: [u8; 16] =
core::array::from_fn(|i| ((r * 7 + b * 11 + i * 13 + 5) & 0xFF) as u8);
all_weights.extend(make_q6k_block(0.01 + r as f32 * 0.005, &ql, &qh, &scales));
}
}
let act_vals: [i8; 32] = [
2, -3, 5, -7, 1, -1, 4, -4, 6, -6, 3, -3, 2, -2, 1, -1, 8, -8, 7, -7, 6, -6, 5, -5, 4,
-4, 3, -3, 2, -2, 1, -1,
];
let acts = make_q8_acts(q8_blocks_per_row, 0.05, &act_vals);
let mut out_avx2 = vec![0.0f32; n_rows];
let mut out_ref = vec![0.0f32; n_rows];
Q6_KAvx2
.matvec_q8_fused(&all_weights, &acts, &mut out_avx2, n_rows, n_cols)
.expect("avx2 fused multi-row");
Q6KRef
.matvec_q8_fused(&all_weights, &acts, &mut out_ref, n_rows, n_cols)
.expect("ref fused multi-row");
for i in 0..n_rows {
let err = (out_avx2[i] - out_ref[i]).abs();
assert!(
err < 1e-3,
"fused multi-row row {i}: avx2={} ref={} err={}",
out_avx2[i],
out_ref[i],
err
);
}
}
#[test]
fn test_q6k_avx2_fused_accumulate_semantics() {
if !std::arch::is_x86_feature_detected!("avx2") {
return;
}
let w_block = make_q6k_block(0.0, &[0u8; 128], &[0u8; 64], &[0u8; 16]);
let acts = make_q8_acts(8, 0.0, &[0i8; 32]);
let mut out = vec![77.0f32; 1];
Q6_KAvx2
.matvec_q8_fused(&w_block, &acts, &mut out, 1, 256)
.expect("avx2 fused accumulate");
assert!(
(out[0] - 77.0).abs() < 1e-5,
"accumulate semantics broken: expected 77.0, got {}",
out[0]
);
}
}