use super::dequant::read_f16;
use super::simd::extract_scale_min;
use super::types::QK_K;
use crate::error::{RealizarError, Result};
/// Scalar fused dequantize-and-dot for Q4_K weights.
///
/// Walks `q4k_data` one 144-byte super-block at a time, rebuilds each weight
/// as `d * scale * nibble - dmin * min` on the fly and accumulates its product
/// with the matching entry of `activations`, so no dequantized buffer is ever
/// materialized.
///
/// Super-block layout as consumed here: bytes 0..2 hold `d`, bytes 2..4 hold
/// `dmin` (both decoded with `read_f16`), bytes 4..16 hold the packed
/// scale/min table read by `extract_scale_min`, and bytes 16..144 hold the
/// packed 4-bit quants. Within every 32-byte run of quants the low nibbles
/// pair with the first 32 activations and the high nibbles with the next 32.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `q4k_data.len()` is not a
/// multiple of 144, or when `activations.len()` is not exactly `QK_K` values
/// per super-block.
pub fn fused_q4k_dot(q4k_data: &[u8], activations: &[f32]) -> Result<f32> {
    const SUPER_BLOCK_BYTES: usize = 144;
    // Bytes in front of the packed quants: d (2) + dmin (2) + scale table (12).
    const HEADER_BYTES: usize = 16;
    // One 32-byte run of quants feeds 32 low-nibble and 32 high-nibble values.
    const HALF_GROUP: usize = 32;

    let data_len = q4k_data.len();
    if !data_len.is_multiple_of(SUPER_BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_K data length {} is not a multiple of super-block size {}",
                data_len, SUPER_BLOCK_BYTES
            ),
        });
    }

    let block_count = data_len / SUPER_BLOCK_BYTES;
    let value_count = block_count * QK_K;
    if activations.len() != value_count {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Activation length {} doesn't match Q4_K values count {}",
                activations.len(),
                value_count
            ),
        });
    }

    // Empty input: nothing to accumulate.
    if block_count == 0 {
        return Ok(0.0);
    }

    contract_pre_matmul!(activations);

    let mut dot = 0.0f32;
    // Activations not yet consumed; shrinks by 64 per quant group.
    let mut pending = activations;

    for block in q4k_data.chunks_exact(SUPER_BLOCK_BYTES) {
        let (header, quants) = block.split_at(HEADER_BYTES);
        let d = read_f16(&header[..2]);
        let dmin = read_f16(&header[2..4]);
        let mut scales = [0u8; 12];
        scales.copy_from_slice(&header[4..]);

        for group_start in (0..QK_K).step_by(64) {
            let packed = &quants[group_start / 2..group_start / 2 + HALF_GROUP];
            let scale_idx = group_start / 32;

            let (sc_lo, min_lo) = extract_scale_min(&scales, scale_idx);
            let (sc_hi, min_hi) = extract_scale_min(&scales, scale_idx + 1);
            let (mul_lo, sub_lo) = (d * sc_lo, dmin * min_lo);
            let (mul_hi, sub_hi) = (d * sc_hi, dmin * min_hi);

            let (acts_lo, rest) = pending.split_at(HALF_GROUP);
            let (acts_hi, rest) = rest.split_at(HALF_GROUP);
            pending = rest;

            // Low nibbles first, then high nibbles: keeps the accumulation
            // order (and therefore the rounding) fixed.
            for (&byte, &act) in packed.iter().zip(acts_lo) {
                let weight = mul_lo * (byte & 0x0F) as f32 - sub_lo;
                dot += weight * act;
            }
            for (&byte, &act) in packed.iter().zip(acts_hi) {
                let weight = mul_hi * (byte >> 4) as f32 - sub_hi;
                dot += weight * act;
            }
        }
    }

    Ok(dot)
}
/// Runtime-dispatched Q4_K fused dot product.
///
/// On x86_64, when the running CPU reports both AVX2 and FMA, the call is
/// forwarded to `fused_q4k_dot_avx2`; in every other case it falls back to the
/// scalar [`fused_q4k_dot`].
///
/// # Errors
///
/// Propagates whatever error the selected implementation returns.
pub fn fused_q4k_dot_simd(q4k_data: &[u8], activations: &[f32]) -> Result<f32> {
    #[cfg(target_arch = "x86_64")]
    {
        let cpu_has_avx2_fma =
            is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
        if cpu_has_avx2_fma {
            // SAFETY: AVX2 and FMA were just confirmed present at runtime.
            // NOTE(review): `fused_q4k_dot_avx2` is defined outside this view;
            // confirm these are the only CPU features it requires.
            return unsafe { fused_q4k_dot_avx2(q4k_data, activations) };
        }
    }

    // Portable path: non-x86_64 targets and CPUs lacking AVX2 or FMA.
    fused_q4k_dot(q4k_data, activations)
}
/// Quantizes a sub-block of activations to 8-bit integers and dots them with
/// 32 unsigned weight bytes.
///
/// Each activation `a` becomes `round((a - min) * scale)`, where `min`/`max`
/// are taken over `act_slice` and `scale` is `127 / (max - min)`, or `1.0`
/// when `max` is not greater than `min`. Slots past `act_slice.len()` stay 0.
///
/// Returns `(sum, scale)` where `sum` is the integer dot product of the
/// quantized activations with the bytes of `q_nibbles_256`. The `min` offset
/// that was subtracted before quantizing is not returned.
///
/// # Panics
///
/// Panics if `act_slice` holds more than 32 values.
///
/// # Safety
///
/// The CPU must support `avx512f`, `avx512bw` and `avx512vnni`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")]
#[allow(unsafe_op_in_unsafe_fn)]
#[inline]
#[allow(dead_code)]
unsafe fn avx512_quantize_dot(
    act_slice: &[f32],
    q_nibbles_256: std::arch::x86_64::__m256i,
) -> (i32, f32) {
    #[allow(clippy::wildcard_imports)]
    use std::arch::x86_64::*;

    // Single pass for both extremes.
    let (lowest, highest) = act_slice.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi): (f32, f32), &a| (lo.min(a), hi.max(a)),
    );
    let act_scale = if highest > lowest {
        127.0 / (highest - lowest)
    } else {
        1.0
    };

    let mut quantized = [0i8; 32];
    for (slot, &a) in act_slice.iter().enumerate() {
        quantized[slot] = ((a - lowest) * act_scale).round() as i8;
    }

    // Widen both operands to 16-bit lanes (weights zero-extended, activations
    // sign-extended) and multiply lane-wise.
    let weights_i16 = _mm512_cvtepu8_epi16(q_nibbles_256);
    let acts_i16 =
        _mm512_cvtepi8_epi16(_mm256_loadu_si256(quantized.as_ptr().cast::<__m256i>()));
    let products = _mm512_mullo_epi16(weights_i16, acts_i16);

    // Fold 32 lanes -> 16 -> 8 in 16-bit arithmetic, then widen to 32-bit.
    let lower_half = _mm512_castsi512_si256(products);
    let upper_half = _mm512_extracti64x4_epi64(products, 1);
    let folded_16 = _mm256_add_epi16(lower_half, upper_half);
    let folded_8 = _mm_add_epi16(
        _mm256_castsi256_si128(folded_16),
        _mm256_extracti128_si256(folded_16, 1),
    );
    let widened = _mm256_cvtepi16_epi32(folded_8);

    let mut lanes = [0i32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), widened);
    (lanes.iter().sum(), act_scale)
}
/// AVX-512 fused Q4_K dot product that quantizes activations to 8-bit integers
/// per 32-value sub-block and multiplies them with the weight nibbles in
/// integer arithmetic.
///
/// For a sub-block with scale `sc`, min `m`, nibbles `q_i` and activations
/// `a_i`, the exact contribution (as computed by the scalar `fused_q4k_dot`) is
///
/// ```text
/// sum_i (d*sc*q_i - dmin*m) * a_i  =  d*sc * sum_i(q_i*a_i)  -  dmin*m * sum_i(a_i)
/// ```
///
/// `avx512_quantize_dot` quantizes `(a_i - act_min) * act_scale` and returns
/// only the integer dot product and `act_scale`, so
/// `sum_i(q_i*a_i) ~= int_sum / act_scale + act_min * sum_i(q_i)`. Both the
/// `act_min * sum_i(q_i)` offset and the `sum_i(a_i)` factor of the min term
/// are reconstructed here; without them the result is only right for
/// activations whose minimum is 0 and whose sum is 32.
///
/// The result is an approximation of the scalar kernel: activations carry
/// roughly 7 bits of precision per sub-block.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `q4k_data.len()` is not a
/// multiple of 144, or when `activations.len()` is not exactly `QK_K` values
/// per super-block.
///
/// # Safety
///
/// The CPU must support `avx512f`, `avx512bw` and `avx512vnni`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")]
#[allow(unsafe_op_in_unsafe_fn)]
#[allow(dead_code)]
unsafe fn fused_q4k_dot_avx512_vnni(q4k_data: &[u8], activations: &[f32]) -> Result<f32> {
    #[allow(clippy::wildcard_imports)]
    use std::arch::x86_64::*;

    const SUPER_BLOCK_BYTES: usize = 144;
    // Values covered by one (scale, min) pair.
    const SUB_BLOCK_VALUES: usize = 32;

    if !q4k_data.len().is_multiple_of(SUPER_BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_K data length {} is not a multiple of super-block size {}",
                q4k_data.len(),
                SUPER_BLOCK_BYTES
            ),
        });
    }
    let num_super_blocks = q4k_data.len() / SUPER_BLOCK_BYTES;
    let expected_values = num_super_blocks * QK_K;
    if activations.len() != expected_values {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Activation length {} doesn't match Q4_K values count {}",
                activations.len(),
                expected_values
            ),
        });
    }

    let nibble_mask = _mm256_set1_epi8(0x0F);
    let mut total_sum = 0.0f32;
    let mut activation_idx = 0;

    for sb_idx in 0..num_super_blocks {
        let sb_start = sb_idx * SUPER_BLOCK_BYTES;

        // Start pulling the next super-block into cache while this one is processed.
        if sb_idx + 1 < num_super_blocks {
            let next_sb = sb_start + SUPER_BLOCK_BYTES;
            _mm_prefetch(q4k_data.as_ptr().add(next_sb).cast::<i8>(), _MM_HINT_T0);
        }

        let d = read_f16(&q4k_data[sb_start..sb_start + 2]);
        let dmin = read_f16(&q4k_data[sb_start + 2..sb_start + 4]);
        let mut scales = [0u8; 12];
        scales.copy_from_slice(&q4k_data[sb_start + 4..sb_start + 16]);
        // Bounds-checked view of the packed quants of this super-block.
        let qs = &q4k_data[sb_start + 16..sb_start + SUPER_BLOCK_BYTES];

        for j in (0..QK_K).step_by(64) {
            let q_chunk = &qs[j / 2..j / 2 + SUB_BLOCK_VALUES];
            let is = j / 32;
            let (sc1, m1) = extract_scale_min(&scales, is);
            let (sc2, m2) = extract_scale_min(&scales, is + 1);

            // 32 packed bytes: low nibbles are the first sub-block, high
            // nibbles the second. The 16-bit shift leaks bits across byte
            // boundaries, which the mask removes.
            let q_bytes = _mm256_loadu_si256(q_chunk.as_ptr().cast::<__m256i>());
            let q_lo = _mm256_and_si256(q_bytes, nibble_mask);
            let q_hi = _mm256_and_si256(_mm256_srli_epi16(q_bytes, 4), nibble_mask);

            for (q_vec, shift, sc, m) in [(q_lo, 0u32, sc1, m1), (q_hi, 4u32, sc2, m2)] {
                let acts = &activations[activation_idx..activation_idx + SUB_BLOCK_VALUES];
                let (int_sum, act_scale) = avx512_quantize_dot(acts, q_vec);

                // Terms needed to undo the helper's `(a - act_min)` offset and
                // to weight the min correction by the real activation sum.
                let act_min = acts.iter().copied().fold(f32::INFINITY, f32::min);
                let act_sum: f32 = acts.iter().sum();
                let q_sum: u32 = q_chunk
                    .iter()
                    .map(|&b| u32::from((b >> shift) & 0x0F))
                    .sum();

                let q_dot_act = int_sum as f32 / act_scale + act_min * q_sum as f32;
                total_sum += d * sc * q_dot_act - dmin * m * act_sum;
                activation_idx += SUB_BLOCK_VALUES;
            }
        }
    }

    Ok(total_sum)
}
include!("q4k_dot_avx2.rs");
include!("fused_q4k_q8k_dot_avx512vnni.rs");
include!("horizontal.rs");
include!("requires.rs");
include!("q4_q8_dot_avx2.rs");
#[cfg(target_arch = "x86_64")]
include!("q4_q8_dot_avx512.rs");
#[cfg(target_arch = "x86_64")]
include!("q4k_dot_ggml_style.rs");