use crate::consts::THETA_LOG_DIM;
/// Returns the squared Euclidean distance between `lhs` and `rhs`.
///
/// Full groups of eight floats are processed with 256-bit loads and a fused
/// multiply-add into a single vector accumulator; the remaining (at most
/// seven) elements are handled with scalar arithmetic.
///
/// # Panics
///
/// Panics if the two slices do not have the same length.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx` and `fma`
/// target features.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "fma,avx")]
#[inline]
pub unsafe fn l2_squared_distance(lhs: &[f32], rhs: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    assert_eq!(lhs.len(), rhs.len());

    let lhs_groups = lhs.chunks_exact(8);
    let rhs_groups = rhs.chunks_exact(8);
    let lhs_tail = lhs_groups.remainder();
    let rhs_tail = rhs_groups.remainder();

    // One accumulator, updated in element order, so the floating-point
    // summation order does not depend on how the loop is unrolled.
    let mut acc = _mm256_setzero_ps();
    for (a, b) in lhs_groups.zip(rhs_groups) {
        // SAFETY: `chunks_exact(8)` guarantees each group holds exactly eight
        // `f32`s, so the unaligned 256-bit loads stay inside the slices.
        let delta = _mm256_sub_ps(_mm256_loadu_ps(a.as_ptr()), _mm256_loadu_ps(b.as_ptr()));
        acc = _mm256_fmadd_ps(delta, delta, acc);
    }

    // Horizontal reduction: add the swapped 128-bit halves, then two rounds
    // of pairwise horizontal adds leave the total in lane 0.
    let halves = _mm256_add_ps(acc, _mm256_permute2f128_ps(acc, acc, 1));
    let pairs = _mm256_hadd_ps(halves, halves);
    let folded = _mm256_hadd_ps(pairs, pairs);
    let mut total = _mm256_cvtss_f32(folded);

    for (a, b) in lhs_tail.iter().zip(rhs_tail) {
        let delta = *a - *b;
        total += delta * delta;
    }
    total
}
/// Splits the low bits of every byte in `vec` into bit planes and ORs them
/// into `binary`.
///
/// For each 32-byte group starting at byte offset `i`, the bytes are shifted
/// left by four so that bit 3 of every byte becomes its most significant bit.
/// `THETA_LOG_DIM` times in a row, the most significant bits of the 32 bytes
/// are gathered into a 32-bit mask and OR-ed into
/// `binary[(3 - j) * (vec.len() / 64) + i / 64]`, in the low or high half of
/// that word depending on `i & 32`; then the bytes are shifted left by one
/// more bit.
///
/// Because the result is OR-ed in, `binary` is not cleared here.
/// NOTE(review): callers presumably pass a zeroed buffer — confirm.
///
/// NOTE(review): the literal `3` and the initial shift by `4` look like they
/// assume `THETA_LOG_DIM == 4`; `3 - j` would underflow for a larger value —
/// confirm against the constant's definition.
///
/// NOTE(review): the plane stride is `vec.len() / 64`, so planes only stay
/// disjoint when `vec.len()` is a multiple of 64 — confirm callers pad to it.
///
/// # Panics
///
/// Panics if `vec.len()` is not a multiple of 32 (the vector loads would
/// otherwise read past the end of `vec`), or if a computed index is out of
/// bounds for `binary`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx` and `avx2`
/// target features.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx,avx2")]
#[inline]
pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let length = vec.len();
    assert!(
        length % 32 == 0,
        "vector_binarize_query: input length {} is not a multiple of 32",
        length
    );
    let plane_stride = length >> 6;

    for (group_index, group) in vec.chunks_exact(32).enumerate() {
        let i = group_index * 32;
        // Each u64 word holds two 32-bit masks: even groups land in the low
        // half, odd groups in the high half.
        let shift = i & 32;
        let word = i >> 6;

        // SAFETY: `chunks_exact(32)` guarantees 32 readable bytes; the load
        // is unaligned-tolerant.
        let mut v = _mm256_loadu_si256(group.as_ptr() as *const __m256i);
        // Bits shifted across byte boundaries never reach a byte's MSB within
        // the shifts performed here when at most four planes are extracted.
        v = _mm256_slli_epi32(v, 4);
        for j in 0..THETA_LOG_DIM as usize {
            let mask = (_mm256_movemask_epi8(v) as u32) as u64;
            binary[(3 - j) * plane_stride + word] |= mask << shift;
            v = _mm256_slli_epi32(v, 1);
        }
    }
}
/// Writes the element-wise difference `x[i] - y[i]` into `res` and returns
/// the `(min, max)` of the written values.
///
/// The number of elements processed is `res.len()`. Full groups of eight are
/// handled with 256-bit AVX operations, the remainder with scalar code.
/// For an empty `res` the result is `(f32::MAX, f32::MIN)`, the initial
/// values of the running minimum and maximum.
///
/// # Panics
///
/// Panics if `x` or `y` is shorter than `res`. (Previously this case read
/// past the end of the shorter input.)
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx` target feature.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx")]
#[inline]
pub unsafe fn min_max_residual(res: &mut [f32], x: &[f32], y: &[f32]) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let length = res.len();
    assert!(
        x.len() >= length && y.len() >= length,
        "min_max_residual: inputs (len {} and {}) are shorter than the output (len {})",
        x.len(),
        y.len(),
        length
    );
    let (x, y) = (&x[..length], &y[..length]);

    let mut res_groups = res.chunks_exact_mut(8);
    let x_groups = x.chunks_exact(8);
    let y_groups = y.chunks_exact(8);
    let x_tail = x_groups.remainder();
    let y_tail = y_groups.remainder();

    let mut min_lanes = _mm256_set1_ps(f32::MAX);
    let mut max_lanes = _mm256_set1_ps(f32::MIN);
    for ((out, a), b) in res_groups.by_ref().zip(x_groups).zip(y_groups) {
        // SAFETY: every group produced by `chunks_exact(_mut)(8)` holds
        // exactly eight `f32`s, so the unaligned loads and the store stay in
        // bounds.
        let diff = _mm256_sub_ps(_mm256_loadu_ps(a.as_ptr()), _mm256_loadu_ps(b.as_ptr()));
        _mm256_storeu_ps(out.as_mut_ptr(), diff);
        min_lanes = _mm256_min_ps(min_lanes, diff);
        max_lanes = _mm256_max_ps(max_lanes, diff);
    }

    // Spill the per-lane extrema to the stack and reduce them with scalar
    // comparisons.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), min_lanes);
    let mut min = lanes
        .iter()
        .fold(f32::MAX, |best, &v| if v < best { v } else { best });
    _mm256_storeu_ps(lanes.as_mut_ptr(), max_lanes);
    let mut max = lanes
        .iter()
        .fold(f32::MIN, |best, &v| if v > best { v } else { best });

    let res_tail = res_groups.into_remainder();
    for ((out, a), b) in res_tail.iter_mut().zip(x_tail).zip(y_tail) {
        let diff = *a - *b;
        *out = diff;
        if diff < min {
            min = diff;
        }
        if diff > max {
            max = diff;
        }
    }
    (min, max)
}
/// Quantizes `vec` into bytes as `(v - lower_bound) * multiplier` rounded to
/// an integer, stores the codes in `quantized[..vec.len()]`, and returns the
/// sum of all codes.
///
/// Full groups of eight floats are converted with AVX2; the low byte of each
/// 32-bit result is kept as the code. The remaining (at most seven) elements
/// use scalar code.
///
/// NOTE(review): the vector path rounds with `_mm256_cvtps_epi32` (current
/// MXCSR rounding mode, round-half-to-even by default) while the scalar tail
/// uses `f32::round` (half away from zero), so exact `.5` ties can round
/// differently between the two paths. Left unchanged here.
///
/// # Panics
///
/// Panics if `quantized` is shorter than `vec`.
///
/// # Safety
///
/// * The running CPU must support the `avx` and `avx2` target features.
/// * For the trailing `vec.len() % 8` elements, the scaled value must be
///   finite and, after rounding, representable as `u8`
///   (requirement of `f32::to_int_unchecked`).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx,avx2")]
#[inline]
pub unsafe fn scalar_quantize(
    quantized: &mut [u8],
    vec: &[f32],
    lower_bound: f32,
    multiplier: f32,
) -> u32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let length = vec.len();
    assert!(
        quantized.len() >= length,
        "scalar_quantize: output (len {}) is shorter than input (len {})",
        quantized.len(),
        length
    );

    let lower = _mm256_set1_ps(lower_bound);
    let scalar = _mm256_set1_ps(multiplier);
    // Per 128-bit lane: gather the low byte of each of the four 32-bit values
    // into bytes 0..4 and zero everything else.
    let low_bytes = _mm256_setr_epi8(
        0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 4, 8, 12, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1,
    );
    let mut sum256 = _mm256_setzero_si256();

    let groups = vec.chunks_exact(8);
    let tail_in = groups.remainder();
    let (head_out, tail_out) = quantized[..length].split_at_mut(length - tail_in.len());

    for (src, dst) in groups.zip(head_out.chunks_exact_mut(8)) {
        // SAFETY: `chunks_exact(8)` guarantees eight readable `f32`s.
        let v = _mm256_loadu_ps(src.as_ptr());
        let codes = _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_sub_ps(v, lower), scalar));
        sum256 = _mm256_add_epi32(sum256, codes);
        let shuffled = _mm256_shuffle_epi8(codes, low_bytes);
        // Go through `u32` before widening: casting the `i32` straight to
        // `u64` sign-extends and would smear 0xFF over bytes 4..8 whenever
        // the fourth code is >= 128.
        let lo = _mm256_extract_epi32(shuffled, 0) as u32;
        let hi = _mm256_extract_epi32(shuffled, 4) as u32;
        let packed = u64::from(lo) | (u64::from(hi) << 32);
        // Byte-wise copy: the output buffer has no alignment guarantee, so a
        // plain aligned `u64` store is not allowed.
        dst.copy_from_slice(&packed.to_ne_bytes());
    }

    // Horizontal reduction of the eight 32-bit partial sums into lane 0.
    let mut combined = _mm256_add_epi32(sum256, _mm256_permute2f128_si256(sum256, sum256, 1));
    combined = _mm256_hadd_epi32(combined, combined);
    combined = _mm256_hadd_epi32(combined, combined);
    let mut sum = _mm256_cvtsi256_si32(combined) as u32;

    for (src, dst) in tail_in.iter().zip(tail_out) {
        // SAFETY: the caller guarantees the rounded value fits in `u8`.
        let code = ((*src - lower_bound) * multiplier)
            .round()
            .to_int_unchecked::<u8>();
        *dst = code;
        sum += u32::from(code);
    }
    sum
}
/// Returns the dot product of `lhs` with the first `lhs.len()` elements of
/// `rhs`.
///
/// Full groups of eight floats are processed with 256-bit loads and a fused
/// multiply-add into a single vector accumulator; the remaining (at most
/// seven) elements are handled with scalar arithmetic.
///
/// # Panics
///
/// Panics if `rhs` is shorter than `lhs`. (Previously this case read past the
/// end of `rhs`.)
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `fma`, `avx` and
/// `avx2` target features.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "fma,avx,avx2")]
#[inline]
pub unsafe fn vector_dot_product(lhs: &[f32], rhs: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    assert!(
        rhs.len() >= lhs.len(),
        "vector_dot_product: rhs (len {}) is shorter than lhs (len {})",
        rhs.len(),
        lhs.len()
    );
    let rhs = &rhs[..lhs.len()];

    let lhs_groups = lhs.chunks_exact(8);
    let rhs_groups = rhs.chunks_exact(8);
    let lhs_tail = lhs_groups.remainder();
    let rhs_tail = rhs_groups.remainder();

    // One accumulator, updated in element order, so the floating-point
    // summation order is the same as a straight pass over 8-wide groups.
    let mut accumulate = _mm256_setzero_ps();
    for (a, b) in lhs_groups.zip(rhs_groups) {
        // SAFETY: `chunks_exact(8)` guarantees each group holds exactly eight
        // `f32`s, so the unaligned 256-bit loads stay inside the slices.
        let va = _mm256_loadu_ps(a.as_ptr());
        let vb = _mm256_loadu_ps(b.as_ptr());
        accumulate = _mm256_fmadd_ps(va, vb, accumulate);
    }

    // Horizontal reduction: add the swapped 128-bit halves, then two rounds
    // of pairwise horizontal adds leave the total in lane 0.
    let halves = _mm256_add_ps(
        accumulate,
        _mm256_permute2f128_ps(accumulate, accumulate, 1),
    );
    let pairs = _mm256_hadd_ps(halves, halves);
    let folded = _mm256_hadd_ps(pairs, pairs);
    let mut sum = _mm256_cvtss_f32(folded);

    for (a, b) in lhs_tail.iter().zip(rhs_tail) {
        sum += *a * *b;
    }
    sum
}
/// Returns the number of bit positions set in both `lhs` and `rhs`, i.e. the
/// sum of `popcount(lhs[i] & rhs[i])` over `i in 0..lhs.len()`.
///
/// Groups of four words are processed with AVX2 using a nibble lookup table
/// for the population count; the remaining (at most three) words use
/// `u64::count_ones`.
///
/// # Panics
///
/// Panics if `rhs` is shorter than `lhs`. (Previously the vector path read
/// past the end of `rhs` in that case.)
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `sse2`, `avx` and
/// `avx2` target features.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2,avx,avx2")]
#[inline]
pub unsafe fn binary_dot_product(lhs: &[u64], rhs: &[u64]) -> u32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    assert!(
        rhs.len() >= lhs.len(),
        "binary_dot_product: rhs (len {}) is shorter than lhs (len {})",
        rhs.len(),
        lhs.len()
    );
    let rhs = &rhs[..lhs.len()];

    let lhs_groups = lhs.chunks_exact(4);
    let rhs_groups = rhs.chunks_exact(4);

    // Leftover words that do not fill a 256-bit group.
    let scalar_part: u32 = lhs_groups
        .remainder()
        .iter()
        .zip(rhs_groups.remainder())
        .map(|(a, b)| (a & b).count_ones())
        .sum();

    // Population count of every 4-bit value, replicated for both 128-bit
    // lanes because `_mm256_shuffle_epi8` looks up per lane.
    let nibble_popcnt = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3,
        3, 4,
    );
    let nibble_mask = _mm256_set1_epi8(15);
    let zero = _mm256_setzero_si256();

    let mut sum256 = _mm256_setzero_si256();
    for (a, b) in lhs_groups.zip(rhs_groups) {
        // SAFETY: `chunks_exact(4)` guarantees four `u64`s (32 bytes) per
        // group; the loads are unaligned-tolerant.
        let va = _mm256_loadu_si256(a.as_ptr() as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr() as *const __m256i);
        let both = _mm256_and_si256(va, vb);

        let low = _mm256_shuffle_epi8(nibble_popcnt, _mm256_and_si256(both, nibble_mask));
        let high = _mm256_shuffle_epi8(
            nibble_popcnt,
            _mm256_and_si256(_mm256_srli_epi64(both, 4), nibble_mask),
        );
        // Sum-of-absolute-differences against zero adds the eight per-byte
        // counts of each 64-bit lane into that lane.
        let per_word = _mm256_sad_epu8(_mm256_add_epi8(low, high), zero);
        sum256 = _mm256_add_epi64(sum256, per_word);
    }

    let halves = _mm_add_epi64(
        _mm256_castsi256_si128(sum256),
        _mm256_extracti128_si256(sum256, 1),
    );
    let total = _mm_add_epi64(halves, _mm_shuffle_epi32(halves, 78));
    // The result is truncated to 32 bits, so reading the low 32 bits is
    // enough, and unlike the 64-bit extract it is available on 32-bit x86.
    scalar_part + _mm_cvtsi128_si32(total) as u32
}