#[cfg(target_arch = "x86_64")]
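/// Fused Q4_K × Q8_K dot product (`_v2`): despite the AVX-512 feature gate,
/// the hot loop is built from 256-bit AVX2-class instructions.
///
/// Each 144-byte Q4_K super-block holds two f16 factors (`d`, `dmin`),
/// 12 bytes of packed 6-bit sub-block scales/mins, and 128 bytes of 4-bit
/// quants covering `QK_K` = 256 values. The kernel accumulates one i32 dot
/// product and one i32 q8 sum per 32-value sub-block, then applies the
/// scales in f32.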
#[target_feature(enable = "avx512f", enable = "avx512vnni", enable = "avx512bw")]
#[allow(unsafe_op_in_unsafe_fn)]
#[allow(clippy::similar_names)]
#[allow(clippy::too_many_lines)]
#[allow(dead_code)]
unsafe fn fused_q4k_q8k_dot_avx512vnni_v2(
q4k_data: &[u8],
q8k_scales: &[f32],
q8k_quants: &[i8],
) -> Result<f32> {
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;
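    // 144 bytes = 2 (f16 d) + 2 (f16 dmin) + 12 (packed 6-bit scales/mins)
    // + 128 (4-bit quants) per 256-value super-block.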
const SUPER_BLOCK_BYTES: usize = 144;
if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
return Err(RealizarError::InvalidShape {
reason: format!(
"Q4_K data length {} is not a multiple of {}",
q4k_data.len(),
SUPER_BLOCK_BYTES
),
});
}
let num_super_blocks = q4k_data.len() / SUPER_BLOCK_BYTES;
let expected_values = num_super_blocks * QK_K;
if q8k_scales.len() < num_super_blocks || q8k_quants.len() < expected_values {
return Err(RealizarError::InvalidShape {
reason: "Q8_K buffer too small".to_string(),
});
}
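    // 0x0F extracts low nibbles; madd against ones widens i16 pair-sums to i32.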
let nibble_mask = _mm256_set1_epi8(0x0F_i8);
let ones_16 = _mm256_set1_epi16(1);
let mut total_acc = _mm256_setzero_ps();
for sb_idx in 0..num_super_blocks {
let sb_start = sb_idx * SUPER_BLOCK_BYTES;
let q8_start = sb_idx * QK_K;
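        // Prefetch the next super-block's bytes while this one is processed.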
if sb_idx + 1 < num_super_blocks {
_mm_prefetch(
q4k_data
.as_ptr()
.add((sb_idx + 1) * SUPER_BLOCK_BYTES)
.cast::<i8>(),
_MM_HINT_T0,
);
_mm_prefetch(
q8k_quants.as_ptr().add((sb_idx + 1) * QK_K).cast::<i8>(),
_MM_HINT_T0,
);
}
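        // Super-block header: f16 `d` and `dmin`, then 12 bytes of packed
        // 6-bit sub-block scales and mins.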
let d = read_f16(&q4k_data[sb_start..sb_start + 2]);
let dmin = read_f16(&q4k_data[sb_start + 2..sb_start + 4]);
let mut scales_raw = [0u8; 12];
scales_raw.copy_from_slice(&q4k_data[sb_start + 4..sb_start + 16]);
let q8_scale = q8k_scales[sb_idx];
let d_q8 = d * q8_scale;
let dmin_q8 = dmin * q8_scale;
let qs_ptr = q4k_data.as_ptr().add(sb_start + 16);
let q8_ptr = q8k_quants.as_ptr().add(q8_start);
        // Per-sub-block accumulators: slot i collects the i32 dot product and
        // raw q8 sum for the i-th 32-value sub-block of this super-block.
        let mut block_dots = [0i32; 8];
        let mut block_q8sums = [0i32; 8];
        for chunk in 0..4 {
            let j = chunk * 64;
            let q_offset = j / 2;
            // 32 packed bytes: low nibbles hold values j..j+32, high nibbles j+32..j+64.
            let q4_bytes = _mm256_loadu_si256(qs_ptr.add(q_offset).cast::<__m256i>());
            let q4_lo = _mm256_and_si256(q4_bytes, nibble_mask);
            let q4_hi = _mm256_and_si256(_mm256_srli_epi16(q4_bytes, 4), nibble_mask);
            let q8_lo = _mm256_loadu_si256(q8_ptr.add(j).cast::<__m256i>());
            let q8_hi = _mm256_loadu_si256(q8_ptr.add(j + 32).cast::<__m256i>());
            // maddubs: unsigned q4 nibbles (0..=15) times signed q8 bytes, adjacent
            // pairs summed into i16; madd against ones widens the i16 pairs to i32.
            let prod_lo_i16 = _mm256_maddubs_epi16(q4_lo, q8_lo);
            let prod_hi_i16 = _mm256_maddubs_epi16(q4_hi, q8_hi);
            let prod_lo_i32 = _mm256_madd_epi16(prod_lo_i16, ones_16);
            let prod_hi_i32 = _mm256_madd_epi16(prod_hi_i16, ones_16);
            // Two horizontal adds plus a cross-lane add collapse the vectors so
            // lane 0 holds the full low-nibble dot and lane 1 the high-nibble dot.
            let prod_h1 = _mm256_hadd_epi32(prod_lo_i32, prod_hi_i32);
            let prod_h2 = _mm256_hadd_epi32(prod_h1, prod_h1);
            let sums_128 = _mm_add_epi32(
                _mm256_castsi256_si128(prod_h2),
                _mm256_extracti128_si256(prod_h2, 1),
            );
            block_dots[chunk * 2] = _mm_extract_epi32(sums_128, 0);
            block_dots[chunk * 2 + 1] = _mm_extract_epi32(sums_128, 1);
            // The dmin correction term needs the plain sum of each q8 sub-block.
            let q8_lo_i16_a = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8_lo));
            let q8_lo_i16_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_lo, 1));
            let q8_hi_i16_a = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8_hi));
            let q8_hi_i16_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_hi, 1));
            let q8_lo_sum = _mm256_add_epi32(
                _mm256_madd_epi16(q8_lo_i16_a, ones_16),
                _mm256_madd_epi16(q8_lo_i16_b, ones_16),
            );
            let q8_hi_sum = _mm256_add_epi32(
                _mm256_madd_epi16(q8_hi_i16_a, ones_16),
                _mm256_madd_epi16(q8_hi_i16_b, ones_16),
            );
            let q8_h1 = _mm256_hadd_epi32(q8_lo_sum, q8_hi_sum);
            let q8_h2 = _mm256_hadd_epi32(q8_h1, q8_h1);
            let q8_sums_128 = _mm_add_epi32(
                _mm256_castsi256_si128(q8_h2),
                _mm256_extracti128_si256(q8_h2, 1),
            );
            block_q8sums[chunk * 2] = _mm_extract_epi32(q8_sums_128, 0);
            block_q8sums[chunk * 2 + 1] = _mm_extract_epi32(q8_sums_128, 1);
        }
        let block_dots_vec = _mm256_loadu_si256(block_dots.as_ptr().cast::<__m256i>());
        let block_q8sums_vec = _mm256_loadu_si256(block_q8sums.as_ptr().cast::<__m256i>());
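        // Unpack the eight 6-bit (scale, min) pairs for this super-block.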
let mut scales = [0.0f32; 8];
let mut mins = [0.0f32; 8];
for i in 0..8 {
let (sc, m) = extract_scale_min(&scales_raw, i);
scales[i] = sc;
mins[i] = m;
}
let scales_vec = _mm256_loadu_ps(scales.as_ptr());
let mins_vec = _mm256_loadu_ps(mins.as_ptr());
let dots_f32 = _mm256_cvtepi32_ps(block_dots_vec);
let q8sums_f32 = _mm256_cvtepi32_ps(block_q8sums_vec);
let d_q8_vec = _mm256_set1_ps(d_q8);
let dmin_q8_vec = _mm256_set1_ps(dmin_q8);
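        // Lane i: d_q8 * sc_i * dot_i - dmin_q8 * m_i * q8sum_i.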
let term1 = _mm256_mul_ps(d_q8_vec, _mm256_mul_ps(scales_vec, dots_f32));
let term2 = _mm256_mul_ps(dmin_q8_vec, _mm256_mul_ps(mins_vec, q8sums_f32));
let result = _mm256_sub_ps(term1, term2);
total_acc = _mm256_add_ps(total_acc, result);
}
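    // Horizontal reduction of the eight f32 lanes: 256 -> 128 -> 64 -> 32 bits.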
let sum128 = _mm_add_ps(
_mm256_castps256_ps128(total_acc),
_mm256_extractf128_ps(total_acc, 1),
);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
Ok(_mm_cvtss_f32(sum32))
}
#[cfg(target_arch = "x86_64")]
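/// Fused Q4_K × Q8_K dot product using AVX-512 VNNI: `_mm512_dpbusd_epi32`
/// (`vpdpbusd`) multiplies unsigned q4 nibbles by signed q8 bytes and
/// accumulates straight into i32 lanes, replacing the maddubs/madd pair of
/// the `_v2` kernel. Data still moves in 256-bit halves, so only the low
/// half of each 512-bit register carries payload.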
#[target_feature(enable = "avx512f", enable = "avx512vnni", enable = "avx512bw")]
#[allow(unsafe_op_in_unsafe_fn)]
#[allow(dead_code)]
unsafe fn fused_q4k_q8k_dot_avx512vnni(
q4k_data: &[u8],
q8k_scales: &[f32],
q8k_quants: &[i8],
) -> Result<f32> {
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;
const SUPER_BLOCK_BYTES: usize = 144;
if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
return Err(RealizarError::InvalidShape {
reason: format!(
"Q4_K data length {} is not a multiple of {}",
q4k_data.len(),
SUPER_BLOCK_BYTES
),
});
}
let num_super_blocks = q4k_data.len() / SUPER_BLOCK_BYTES;
let expected_values = num_super_blocks * QK_K;
if q8k_scales.len() < num_super_blocks || q8k_quants.len() < expected_values {
return Err(RealizarError::InvalidShape {
reason: "Q8_K buffer too small".to_string(),
});
}
    let nibble_mask = _mm256_set1_epi8(0x0F_i8);
    let mut total_acc = 0.0f32;
for sb_idx in 0..num_super_blocks {
let sb_start = sb_idx * SUPER_BLOCK_BYTES;
let q8_start = sb_idx * QK_K;
if sb_idx + 1 < num_super_blocks {
_mm_prefetch(
q4k_data
.as_ptr()
.add((sb_idx + 1) * SUPER_BLOCK_BYTES)
.cast::<i8>(),
_MM_HINT_T0,
);
_mm_prefetch(
q8k_quants
.as_ptr()
.add((sb_idx + 1) * QK_K)
.cast::<i8>(),
_MM_HINT_T0,
);
}
let d = read_f16(&q4k_data[sb_start..sb_start + 2]);
let dmin = read_f16(&q4k_data[sb_start + 2..sb_start + 4]);
let mut scales = [0u8; 12];
scales.copy_from_slice(&q4k_data[sb_start + 4..sb_start + 16]);
let q8_scale = q8k_scales[sb_idx];
let d_q8 = d * q8_scale;
let dmin_q8 = dmin * q8_scale;
let qs_ptr = q4k_data.as_ptr().add(sb_start + 16);
let q8_ptr = q8k_quants.as_ptr().add(q8_start);
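        // Each 64-quant chunk spans two 32-value sub-blocks, each with its own
        // 6-bit scale/min pair.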
for j in (0..QK_K).step_by(64) {
let q_offset = j / 2;
let is = j / 32;
let (sc1, m1) = extract_scale_min(&scales, is);
let (sc2, m2) = extract_scale_min(&scales, is + 1);
            // 32 packed bytes: low nibbles hold values j..j+32, high nibbles j+32..j+64.
            let q4_256 = _mm256_loadu_si256(qs_ptr.add(q_offset).cast::<__m256i>());
            let q4_lo_256 = _mm256_and_si256(q4_256, nibble_mask);
            let q4_hi_256 = _mm256_and_si256(_mm256_srli_epi16(q4_256, 4), nibble_mask);
let q8_lo_256 = _mm256_loadu_si256(q8_ptr.add(j).cast::<__m256i>());
let q8_hi_256 = _mm256_loadu_si256(q8_ptr.add(j + 32).cast::<__m256i>());
            // Zero-extend into 512-bit registers so every lane fed to VNNI is
            // defined (a plain cast leaves the upper 256 bits undefined).
            let q4_lo_512 = _mm512_zextsi256_si512(q4_lo_256);
            let q4_hi_512 = _mm512_zextsi256_si512(q4_hi_256);
            let q8_lo_512 = _mm512_zextsi256_si512(q8_lo_256);
            let q8_hi_512 = _mm512_zextsi256_si512(q8_hi_256);
            let acc_lo = _mm512_dpbusd_epi32(_mm512_setzero_si512(), q4_lo_512, q8_lo_512);
            let acc_hi = _mm512_dpbusd_epi32(_mm512_setzero_si512(), q4_hi_512, q8_hi_512);
            // Only the low 256 bits carry data; drop the zeroed upper half.
            let acc_lo_256 = _mm512_castsi512_si256(acc_lo);
            let acc_hi_256 = _mm512_castsi512_si256(acc_hi);
let sum_lo = horizontal_sum_epi32_256(acc_lo_256);
let sum_hi = horizontal_sum_epi32_256(acc_hi_256);
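            // The dmin correction term needs the plain sum of each q8 sub-block.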
let q8_lo_256_i16_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8_lo_256));
let q8_lo_256_i16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_lo_256, 1));
let q8_sum_lo = horizontal_sum_epi16_256(q8_lo_256_i16_lo)
+ horizontal_sum_epi16_256(q8_lo_256_i16_hi);
let q8_hi_256_i16_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8_hi_256));
let q8_hi_256_i16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_hi_256, 1));
let q8_sum_hi = horizontal_sum_epi16_256(q8_hi_256_i16_lo)
+ horizontal_sum_epi16_256(q8_hi_256_i16_hi);
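            // Accumulate dot * d * sc minus q8-sum * dmin * m per sub-block.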
total_acc += d_q8 * sc1 * (sum_lo as f32) - dmin_q8 * m1 * (q8_sum_lo as f32);
total_acc += d_q8 * sc2 * (sum_hi as f32) - dmin_q8 * m2 * (q8_sum_hi as f32);
}
}
Ok(total_acc)
}