#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn compute_q4_q8_dots_8blocks(
qs_ptr: *const u8,
q8_c0_lo: std::arch::x86_64::__m256i,
q8_c0_hi: std::arch::x86_64::__m256i,
q8_c1_lo: std::arch::x86_64::__m256i,
q8_c1_hi: std::arch::x86_64::__m256i,
q8_c2_lo: std::arch::x86_64::__m256i,
q8_c2_hi: std::arch::x86_64::__m256i,
q8_c3_lo: std::arch::x86_64::__m256i,
q8_c3_hi: std::arch::x86_64::__m256i,
nibble_mask: std::arch::x86_64::__m256i,
ones_16: std::arch::x86_64::__m256i,
) -> std::arch::x86_64::__m256i {
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;
unsafe {
let q4_c0 = _mm256_loadu_si256(qs_ptr.cast::<__m256i>());
let q4_c1 = _mm256_loadu_si256(qs_ptr.add(32).cast::<__m256i>());
let q4_c2 = _mm256_loadu_si256(qs_ptr.add(64).cast::<__m256i>());
let q4_c3 = _mm256_loadu_si256(qs_ptr.add(96).cast::<__m256i>());
let dot0 = q4_q8_chunk_dot(q4_c0, q8_c0_lo, q8_c0_hi, nibble_mask, ones_16);
let dot1 = q4_q8_chunk_dot(q4_c1, q8_c1_lo, q8_c1_hi, nibble_mask, ones_16);
let dot2 = q4_q8_chunk_dot(q4_c2, q8_c2_lo, q8_c2_hi, nibble_mask, ones_16);
let dot3 = q4_q8_chunk_dot(q4_c3, q8_c3_lo, q8_c3_hi, nibble_mask, ones_16);
let mut result = _mm256_setzero_si256();
result = _mm256_insert_epi32(result, _mm_cvtsi128_si32(_mm256_castsi256_si128(dot0)), 0);
result = _mm256_insert_epi32(
result,
_mm_extract_epi32(_mm256_castsi256_si128(dot0), 1),
1,
);
result = _mm256_insert_epi32(result, _mm_cvtsi128_si32(_mm256_castsi256_si128(dot1)), 2);
result = _mm256_insert_epi32(
result,
_mm_extract_epi32(_mm256_castsi256_si128(dot1), 1),
3,
);
result = _mm256_insert_epi32(result, _mm_cvtsi128_si32(_mm256_castsi256_si128(dot2)), 4);
result = _mm256_insert_epi32(
result,
_mm_extract_epi32(_mm256_castsi256_si128(dot2), 1),
5,
);
result = _mm256_insert_epi32(result, _mm_cvtsi128_si32(_mm256_castsi256_si128(dot3)), 6);
result = _mm256_insert_epi32(
result,
_mm_extract_epi32(_mm256_castsi256_si128(dot3), 1),
7,
);
result
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn q4_q8_chunk_dot(
q4_packed: std::arch::x86_64::__m256i,
q8_lo: std::arch::x86_64::__m256i,
q8_hi: std::arch::x86_64::__m256i,
nibble_mask: std::arch::x86_64::__m256i,
ones_16: std::arch::x86_64::__m256i,
) -> std::arch::x86_64::__m256i {
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;
unsafe {
let q4_lo = _mm256_and_si256(q4_packed, nibble_mask);
let q4_hi = _mm256_and_si256(_mm256_srli_epi16(q4_packed, 4), nibble_mask);
let prod_lo_i16 = _mm256_maddubs_epi16(q4_lo, q8_lo);
let prod_hi_i16 = _mm256_maddubs_epi16(q4_hi, q8_hi);
let prod_lo_i32 = _mm256_madd_epi16(prod_lo_i16, ones_16);
let prod_hi_i32 = _mm256_madd_epi16(prod_hi_i16, ones_16);
let sum_lo = hsum_epi32(prod_lo_i32);
let sum_hi = hsum_epi32(prod_hi_i32);
let mut result = _mm256_setzero_si256();
result = _mm256_insert_epi32(result, sum_lo, 0);
result = _mm256_insert_epi32(result, sum_hi, 1);
result
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_epi32(v: std::arch::x86_64::__m256i) -> i32 {
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;
let sum128 = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
let sum64 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, 0b10_11_00_01));
let sum32 = _mm_add_epi32(sum64, _mm_shuffle_epi32(sum64, 0b00_00_10_10));
_mm_cvtsi128_si32(sum32)
}
#[cfg(test)]
#[path = "fused_k_tests.rs"]
mod fused_k_tests;