/// Fused dot product of a Q4_K-quantised row against Q8_K-quantised
/// activations, accumulated in SIMD registers.
///
/// Per 144-byte Q4_K super-block (as read below): f16 `d`, f16 `dmin`,
/// 12 packed scale/min bytes (decoded by `extract_scale_min`, presumably
/// the K-quant 6-bit packing — confirm against that helper), then 128
/// bytes of 4-bit quants. Each super-block covers `QK_K` values
/// (presumably 256, since 128 nibble bytes hold 256 nibbles — TODO confirm).
/// `q8k_scales` holds one f32 scale and `q8k_quants` `QK_K` i8 values per
/// super-block.
///
/// NOTE(review): despite the name and the `avx512*` target features, the
/// body uses only AVX2/SSE intrinsics — there is no VNNI `dpbusd` call.
/// The gates still restrict dispatch to AVX-512-capable CPUs; confirm that
/// requirement is intentional.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `q4k_data` is not a whole
/// number of super-blocks, or when the Q8_K scale/quant buffers are shorter
/// than `q4k_data` implies.
///
/// # Safety
/// The caller must ensure the CPU supports avx512f/avx512vnni/avx512bw
/// (which imply the AVX2/SSE features actually used).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512vnni", enable = "avx512bw")]
#[allow(unsafe_op_in_unsafe_fn)]
#[allow(clippy::similar_names)]
#[allow(clippy::too_many_lines)]
#[allow(dead_code)]
unsafe fn fused_q4k_q8k_dot_avx512vnni_opt(
    q4k_data: &[u8],
    q8k_scales: &[f32],
    q8k_quants: &[i8],
) -> Result<f32> {
    #[allow(clippy::wildcard_imports)]
    use std::arch::x86_64::*;
    // 2 (f16 d) + 2 (f16 dmin) + 12 (packed scales/mins) + 128 (nibbles).
    const SUPER_BLOCK_BYTES: usize = 144;
    if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_K data length {} is not a multiple of {}",
                q4k_data.len(),
                SUPER_BLOCK_BYTES
            ),
        });
    }
    let num_super_blocks = q4k_data.len() / SUPER_BLOCK_BYTES;
    let expected_values = num_super_blocks * QK_K;
    if q8k_scales.len() < num_super_blocks || q8k_quants.len() < expected_values {
        return Err(RealizarError::InvalidShape {
            reason: "Q8_K buffer too small".to_string(),
        });
    }
    // Loop-invariant constants: the low-nibble mask, and the all-ones i16
    // vector `madd` uses to widen adjacent i16 pairs into i32 lanes.
    let nibble_mask = _mm256_set1_epi8(0x0F_i8);
    let ones_16 = _mm256_set1_epi16(1);
    let mut total_acc = 0.0f32;
    for sb_idx in 0..num_super_blocks {
        let sb_start = sb_idx * SUPER_BLOCK_BYTES;
        let q8_start = sb_idx * QK_K;
        // Prefetch next super-block's inputs while this one is processed.
        if sb_idx + 1 < num_super_blocks {
            _mm_prefetch(
                q4k_data
                    .as_ptr()
                    .add((sb_idx + 1) * SUPER_BLOCK_BYTES)
                    .cast::<i8>(),
                _MM_HINT_T0,
            );
            _mm_prefetch(
                q8k_quants.as_ptr().add((sb_idx + 1) * QK_K).cast::<i8>(),
                _MM_HINT_T0,
            );
        }
        // Super-block header: f16 scale `d`, f16 min-scale `dmin`, then the
        // 12 packed scale/min bytes copied out for `extract_scale_min`.
        let d = read_f16(&q4k_data[sb_start..sb_start + 2]);
        let dmin = read_f16(&q4k_data[sb_start + 2..sb_start + 4]);
        let mut scales_raw = [0u8; 12];
        scales_raw.copy_from_slice(&q4k_data[sb_start + 4..sb_start + 16]);
        // Fold the per-super-block Q8_K scale into both Q4_K scales once.
        let q8_scale = q8k_scales[sb_idx];
        let d_q8 = d * q8_scale;
        let dmin_q8 = dmin * q8_scale;
        let qs_ptr = q4k_data.as_ptr().add(sb_start + 16);
        let q8_ptr = q8k_quants.as_ptr().add(q8_start);
        // Integer accumulators per 32-value sub-block: raw nibble·q8 dot
        // products, and plain q8 sums (needed for the `dmin` correction).
        let mut block_dots = [0i32; 8];
        let mut block_q8sums = [0i32; 8];
        for chunk in 0..4 {
            // Each chunk covers 64 values: 32 nibble bytes where the low
            // nibble belongs to values j..j+32 and the high nibble to values
            // j+32..j+64 (matching the two q8 loads below).
            let j = chunk * 64;
            let q_offset = j / 2;
            let q4_bytes = _mm256_loadu_si256(qs_ptr.add(q_offset).cast::<__m256i>());
            let q4_lo = _mm256_and_si256(q4_bytes, nibble_mask);
            // 16-bit shift is fine: the mask discards bits shifted in from
            // the neighbouring byte.
            let q4_hi = _mm256_and_si256(_mm256_srli_epi16(q4_bytes, 4), nibble_mask);
            let q8_lo = _mm256_loadu_si256(q8_ptr.add(j).cast::<__m256i>());
            let q8_hi = _mm256_loadu_si256(q8_ptr.add(j + 32).cast::<__m256i>());
            // maddubs treats its first operand as u8 (nibbles are 0..=15) and
            // its second as i8; pair sums (<= 2*15*127) cannot saturate i16.
            let prod_lo_i16 = _mm256_maddubs_epi16(q4_lo, q8_lo);
            let prod_hi_i16 = _mm256_maddubs_epi16(q4_hi, q8_hi);
            // Widen i16 pairs to i32 lanes so the reductions cannot overflow.
            let prod_lo_i32 = _mm256_madd_epi16(prod_lo_i16, ones_16);
            let prod_hi_i32 = _mm256_madd_epi16(prod_hi_i16, ones_16);
            // Reduce both 256-bit registers to scalars: fold the upper half
            // onto the lower, then hadd twice. After the second hadd, lane 0
            // holds the low sub-block total and lane 1 the high one.
            let prod_lo_128 = _mm_add_epi32(
                _mm256_castsi256_si128(prod_lo_i32),
                _mm256_extracti128_si256(prod_lo_i32, 1),
            );
            let prod_hi_128 = _mm_add_epi32(
                _mm256_castsi256_si128(prod_hi_i32),
                _mm256_extracti128_si256(prod_hi_i32, 1),
            );
            let prod_lo_64 = _mm_hadd_epi32(prod_lo_128, prod_hi_128);
            let prod_32 = _mm_hadd_epi32(prod_lo_64, prod_lo_64);
            let block_idx = chunk * 2;
            block_dots[block_idx] = _mm_extract_epi32(prod_32, 0);
            block_dots[block_idx + 1] = _mm_extract_epi32(prod_32, 1);
            // Same reduction for the plain q8 sums: sign-extend i8 -> i16,
            // widen to i32 via madd-by-ones, fold halves, hadd twice.
            let q8_lo_i16_a = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8_lo));
            let q8_lo_i16_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_lo, 1));
            let q8_hi_i16_a = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8_hi));
            let q8_hi_i16_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_hi, 1));
            let q8_lo_i32_a = _mm256_madd_epi16(q8_lo_i16_a, _mm256_set1_epi16(1));
            let q8_lo_i32_b = _mm256_madd_epi16(q8_lo_i16_b, _mm256_set1_epi16(1));
            let q8_hi_i32_a = _mm256_madd_epi16(q8_hi_i16_a, _mm256_set1_epi16(1));
            let q8_hi_i32_b = _mm256_madd_epi16(q8_hi_i16_b, _mm256_set1_epi16(1));
            let q8_lo_sum = _mm256_add_epi32(q8_lo_i32_a, q8_lo_i32_b);
            let q8_hi_sum = _mm256_add_epi32(q8_hi_i32_a, q8_hi_i32_b);
            let q8_lo_128 = _mm_add_epi32(
                _mm256_castsi256_si128(q8_lo_sum),
                _mm256_extracti128_si256(q8_lo_sum, 1),
            );
            let q8_hi_128 = _mm_add_epi32(
                _mm256_castsi256_si128(q8_hi_sum),
                _mm256_extracti128_si256(q8_hi_sum, 1),
            );
            let q8_64 = _mm_hadd_epi32(q8_lo_128, q8_hi_128);
            let q8_32 = _mm_hadd_epi32(q8_64, q8_64);
            block_q8sums[block_idx] = _mm_extract_epi32(q8_32, 0);
            block_q8sums[block_idx + 1] = _mm_extract_epi32(q8_32, 1);
        }
        // Decode the eight (scale, min) pairs for this super-block.
        let mut scales = [0.0f32; 8];
        let mut mins = [0.0f32; 8];
        for i in 0..8 {
            let (sc, m) = extract_scale_min(&scales_raw, i);
            scales[i] = sc;
            mins[i] = m;
        }
        // Combine all eight sub-blocks at once in f32:
        //   acc += d*q8_scale*sc_i*dot_i - dmin*q8_scale*m_i*sum(q8_i)
        let scales_vec = _mm256_loadu_ps(scales.as_ptr());
        let mins_vec = _mm256_loadu_ps(mins.as_ptr());
        let dots_i32 = _mm256_loadu_si256(block_dots.as_ptr().cast::<__m256i>());
        let q8sums_i32 = _mm256_loadu_si256(block_q8sums.as_ptr().cast::<__m256i>());
        let dots_f32 = _mm256_cvtepi32_ps(dots_i32);
        let q8sums_f32 = _mm256_cvtepi32_ps(q8sums_i32);
        let d_q8_vec = _mm256_set1_ps(d_q8);
        let dmin_q8_vec = _mm256_set1_ps(dmin_q8);
        let term1 = _mm256_mul_ps(d_q8_vec, _mm256_mul_ps(scales_vec, dots_f32));
        let term2 = _mm256_mul_ps(dmin_q8_vec, _mm256_mul_ps(mins_vec, q8sums_f32));
        let result = _mm256_sub_ps(term1, term2);
        // Horizontal sum of the eight f32 lanes into the scalar accumulator.
        let sum128 = _mm_add_ps(
            _mm256_castps256_ps128(result),
            _mm256_extractf128_ps(result, 1),
        );
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        total_acc += _mm_cvtss_f32(sum32);
    }
    Ok(total_acc)
}
/// Horizontally sums the four i32 lanes of a 128-bit vector.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (which covers the SSSE3
/// `hadd` instructions used here).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
unsafe fn hsum_epi32_128(v: std::arch::x86_64::__m128i) -> i32 {
    use std::arch::x86_64::{_mm_cvtsi128_si32, _mm_hadd_epi32};
    // SAFETY: the caller guarantees the required target features (fn
    // contract). The explicit block matches `hsum_epi32_256`'s style and
    // keeps this fn valid under `unsafe_op_in_unsafe_fn` (edition 2024) —
    // previously the ops were neither wrapped nor `#[allow]`ed.
    unsafe {
        // [a b c d] -> [a+b c+d a+b c+d] -> lane 0 holds a+b+c+d.
        let sum64 = _mm_hadd_epi32(v, v);
        let sum32 = _mm_hadd_epi32(sum64, sum64);
        _mm_cvtsi128_si32(sum32)
    }
}
/// Horizontally sums the eight i32 lanes of a 256-bit vector.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
unsafe fn hsum_epi32_256(v: std::arch::x86_64::__m256i) -> i32 {
    use std::arch::x86_64::{
        _mm256_castsi256_si128, _mm256_extracti128_si256, _mm_add_epi32, _mm_cvtsi128_si32,
        _mm_hadd_epi32,
    };
    // SAFETY: caller guarantees AVX2 per the fn contract.
    unsafe {
        // Fold the upper 128-bit half onto the lower, then reduce the four
        // remaining lanes with two pairwise horizontal adds; lane 0 ends up
        // with the total.
        let halves = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
        let pairs = _mm_hadd_epi32(halves, halves);
        _mm_cvtsi128_si32(_mm_hadd_epi32(pairs, pairs))
    }
}
/// Horizontally sums the eight i32 lanes of a 256-bit vector.
///
/// NOTE(review): functional duplicate of `hsum_epi32_256` in this file —
/// consider consolidating on one helper.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
unsafe fn horizontal_sum_epi32_256(v: std::arch::x86_64::__m256i) -> i32 {
    use std::arch::x86_64::{
        _mm256_castsi256_si128, _mm256_extracti128_si256, _mm_add_epi32, _mm_cvtsi128_si32,
        _mm_hadd_epi32,
    };
    // SAFETY: the caller guarantees AVX2 (fn contract). The explicit block
    // keeps this fn valid under `unsafe_op_in_unsafe_fn` (edition 2024) and
    // matches `hsum_epi32_256`'s style — previously the ops were neither
    // wrapped nor `#[allow]`ed.
    unsafe {
        // Fold upper half onto lower, then two pairwise horizontal adds.
        let hi = _mm256_extracti128_si256(v, 1);
        let lo = _mm256_castsi256_si128(v);
        let sum128 = _mm_add_epi32(lo, hi);
        let sum64 = _mm_hadd_epi32(sum128, sum128);
        let sum32 = _mm_hadd_epi32(sum64, sum64);
        _mm_cvtsi128_si32(sum32)
    }
}
/// Horizontally sums the sixteen i16 lanes of a 256-bit vector into an i32.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(unsafe_op_in_unsafe_fn)]
#[allow(dead_code)]
unsafe fn horizontal_sum_epi16_256(v: std::arch::x86_64::__m256i) -> i32 {
    use std::arch::x86_64::{
        _mm256_castsi256_si128, _mm256_extracti128_si256, _mm256_madd_epi16, _mm256_set1_epi16,
        _mm_add_epi32, _mm_cvtsi128_si32, _mm_hadd_epi32,
    };
    // madd against an all-ones vector widens adjacent i16 pairs into eight
    // i32 partial sums (a sum of two i16 values always fits in i32).
    let widened = _mm256_madd_epi16(v, _mm256_set1_epi16(1));
    // Reduce the eight i32 lanes: fold the upper 128-bit half onto the
    // lower, then two pairwise horizontal adds leave the total in lane 0.
    let folded = _mm_add_epi32(
        _mm256_castsi256_si128(widened),
        _mm256_extracti128_si256(widened, 1),
    );
    let pairs = _mm_hadd_epi32(folded, folded);
    _mm_cvtsi128_si32(_mm_hadd_epi32(pairs, pairs))
}