use crate::simd_utils::horizontal_sum_256;
use crate::{AsRefItem, DenseVector1D, Float, SparseVector1D, Vector1D};
use itertools::izip;
use std::arch::x86_64::*;
use std::iter::zip;
use half::f16;
#[inline]
pub fn dense_dot_product<T>(query: &[T], document: &[T]) -> f32
where
T: Float,
{
dense_dot_product_unrolled(query, document)
}
#[inline]
pub fn dense_dot_product_unrolled<T>(query: &[T], document: &[T]) -> f32
where
T: Float,
{
const N_LANES: usize = 8;
let mut r = [0.0; N_LANES];
let chunks = query.len() / N_LANES;
for (q_chunk, v_chunk) in zip(query.chunks_exact(N_LANES), document.chunks_exact(N_LANES)) {
for i in 0..N_LANES {
let d = q_chunk[i].to_f32().unwrap() * v_chunk[i].to_f32().unwrap();
r[i] += d;
}
}
r.iter().sum::<f32>()
+ dense_dot_product_general(&query[N_LANES * chunks..], &document[N_LANES * chunks..])
}
#[inline]
pub fn dense_dot_product_general<T>(query: &[T], document: &[T]) -> f32
where
T: Float,
{
query.iter().zip(document).fold(0.0, |acc, (a, b)| {
acc + (a.to_f32().unwrap() * b.to_f32().unwrap())
})
}
#[inline]
pub fn dot_product_dense_sparse<T1, T, U, F>(
query: &DenseVector1D<T1>,
array: &SparseVector1D<U, T>,
) -> f32
where
T1: AsRefItem<Item = F>,
T: AsRefItem<Item = F>,
U: AsRefItem<Item = u16>,
F: Float,
{
const N_LANES: usize = 4;
let mut result = [0.0; N_LANES];
let query_slice = query.values_as_slice();
let chunk_iter = array
.components_as_slice()
.iter()
.zip(array.values_as_slice())
.array_chunks::<N_LANES>();
for chunk in chunk_iter {
result[0] += unsafe {
(*query_slice.get_unchecked(*chunk[0].0 as usize))
.to_f32()
.unwrap()
* (*chunk[0].1).to_f32().unwrap()
};
result[1] += unsafe {
(*query_slice.get_unchecked(*chunk[1].0 as usize))
.to_f32()
.unwrap()
* (*chunk[1].1).to_f32().unwrap()
};
result[2] += unsafe {
(*query_slice.get_unchecked(*chunk[2].0 as usize))
.to_f32()
.unwrap()
* (*chunk[2].1).to_f32().unwrap()
};
result[3] += unsafe {
(*query_slice.get_unchecked(*chunk[3].0 as usize))
.to_f32()
.unwrap()
* (*chunk[3].1).to_f32().unwrap()
};
}
let l = array.components_as_slice().len();
let rem = l % N_LANES;
for (&i, &v) in array.components_as_slice()[l - rem..]
.iter()
.zip(&array.values_as_slice()[l - rem..])
{
result[0] += unsafe {
(*query_slice.get_unchecked(i as usize)).to_f32().unwrap() * v.to_f32().unwrap()
};
}
result.iter().sum()
}
#[inline]
#[must_use]
pub fn sparse_dot_product_with_merge<F, U, T>(
query: &SparseVector1D<U, T>,
vector: &SparseVector1D<U, T>,
) -> f32
where
U: AsRefItem<Item = u16>,
T: AsRefItem<Item = F>,
F: Float,
{
let mut result = 0.0;
let mut i = 0;
for (&q_id, &q_v) in query
.components_as_slice()
.iter()
.zip(query.values_as_slice())
{
unsafe {
while i < vector.components_as_slice().len()
&& *vector.components_as_slice().get_unchecked(i) < q_id
{
i += 1;
}
if i == vector.components_as_slice().len() {
break;
}
if *vector.components_as_slice().get_unchecked(i) == q_id {
result += (*vector.values_as_slice().get_unchecked(i))
.to_f32()
.unwrap()
* q_v.to_f32().unwrap();
}
}
}
result
}
#[inline]
pub fn dot_product_batch_4<T>(query: &[T], values: [&[T]; 4]) -> [f32; 4]
where
T: Float + DotProduct<T>,
{
unsafe { T::dot_product_batch_4(query, values) }
}
#[inline]
pub fn dot_product_unrolled_avx<T>(query: &[T], values: &[T]) -> f32
where
T: Float + DotProduct<T>,
{
unsafe { T::dot_product_unrolled_avx2(query, values) }
}
pub trait DotProduct<U> {
unsafe fn dot_product_unrolled_avx2(query: &[U], values: &[U]) -> f32;
unsafe fn dot_product_batch_4(query: &[U], values: [&[U]; 4]) -> [f32; 4];
}
impl DotProduct<f32> for f32 {
unsafe fn dot_product_unrolled_avx2(query: &[f32], document: &[f32]) -> f32 {
const N_LANES: usize = 8;
let mut sum = _mm256_setzero_ps();
let chunks = query.len() / N_LANES;
for (q_chunk, v_chunk) in query
.chunks_exact(N_LANES)
.zip(document.chunks_exact(N_LANES))
{
let q_vec = _mm256_loadu_ps(q_chunk.as_ptr());
let v_vec = _mm256_loadu_ps(v_chunk.as_ptr());
let prod = _mm256_mul_ps(q_vec, v_vec);
sum = _mm256_add_ps(sum, prod);
}
let mut result = [0.0; N_LANES];
_mm256_storeu_ps(result.as_mut_ptr(), sum);
let simd_sum: f32 = result.iter().sum();
simd_sum
+ query[chunks * N_LANES..]
.iter()
.zip(&document[chunks * N_LANES..])
.fold(0.0, |acc, (&a, &b)| acc + a * b)
}
unsafe fn dot_product_batch_4(query: &[f32], vectors: [&[f32]; 4]) -> [f32; 4] {
const N_LANES: usize = 8;
let mut sum_0 = _mm256_setzero_ps();
let mut sum_1 = _mm256_setzero_ps();
let mut sum_2 = _mm256_setzero_ps();
let mut sum_3 = _mm256_setzero_ps();
for (q_chunk, v0_chunk, v1_chunk, v2_chunk, v3_chunk) in izip!(
query.chunks_exact(N_LANES),
vectors[0].chunks_exact(N_LANES),
vectors[1].chunks_exact(N_LANES),
vectors[2].chunks_exact(N_LANES),
vectors[3].chunks_exact(N_LANES)
) {
let q_values = _mm256_loadu_ps(q_chunk.as_ptr());
let v0_values = _mm256_loadu_ps(v0_chunk.as_ptr());
let v1_values = _mm256_loadu_ps(v1_chunk.as_ptr());
let v2_values = _mm256_loadu_ps(v2_chunk.as_ptr());
let v3_values = _mm256_loadu_ps(v3_chunk.as_ptr());
sum_0 = _mm256_fmadd_ps(q_values, v0_values, sum_0);
sum_1 = _mm256_fmadd_ps(q_values, v1_values, sum_1);
sum_2 = _mm256_fmadd_ps(q_values, v2_values, sum_2);
sum_3 = _mm256_fmadd_ps(q_values, v3_values, sum_3);
}
[
horizontal_sum_256(sum_0),
horizontal_sum_256(sum_1),
horizontal_sum_256(sum_2),
horizontal_sum_256(sum_3),
]
}
}
impl DotProduct<f16> for f16 {
#[target_feature(enable = "avx2,f16c")]
unsafe fn dot_product_unrolled_avx2(query: &[f16], document: &[f16]) -> f32 {
const N_LANES: usize = 8;
let mut sum = _mm256_setzero_ps();
let chunks = query.len() / N_LANES;
for (q_chunk, v_chunk) in query
.chunks_exact(N_LANES)
.zip(document.chunks_exact(N_LANES))
{
let q_half = _mm_loadu_si128(q_chunk.as_ptr() as *const __m128i);
let v_half = _mm_loadu_si128(v_chunk.as_ptr() as *const __m128i);
let q_vec = _mm256_cvtph_ps(q_half);
let v_vec = _mm256_cvtph_ps(v_half);
let prod = _mm256_mul_ps(q_vec, v_vec);
sum = _mm256_add_ps(sum, prod);
}
let mut result = [0.0; N_LANES];
_mm256_storeu_ps(result.as_mut_ptr(), sum);
let simd_sum: f32 = result.iter().sum();
let remainder_start = chunks * N_LANES;
let remainder_sum: f32 = query[remainder_start..]
.iter()
.zip(&document[remainder_start..])
.fold(0.0, |acc, (&q, &v)| acc + q.to_f32() * v.to_f32());
simd_sum + remainder_sum
}
#[target_feature(enable = "avx2,f16c")]
unsafe fn dot_product_batch_4(query: &[f16], vectors: [&[f16]; 4]) -> [f32; 4] {
const N_LANES: usize = 8;
let mut sum_0 = _mm256_setzero_ps();
let mut sum_1 = _mm256_setzero_ps();
let mut sum_2 = _mm256_setzero_ps();
let mut sum_3 = _mm256_setzero_ps();
for (q_chunk, v0_chunk, v1_chunk, v2_chunk, v3_chunk) in izip!(
query.chunks_exact(N_LANES),
vectors[0].chunks_exact(N_LANES),
vectors[1].chunks_exact(N_LANES),
vectors[2].chunks_exact(N_LANES),
vectors[3].chunks_exact(N_LANES)
) {
let q_half = _mm_loadu_si128(q_chunk.as_ptr() as *const __m128i);
let v0_half = _mm_loadu_si128(v0_chunk.as_ptr() as *const __m128i);
let v1_half = _mm_loadu_si128(v1_chunk.as_ptr() as *const __m128i);
let v2_half = _mm_loadu_si128(v2_chunk.as_ptr() as *const __m128i);
let v3_half = _mm_loadu_si128(v3_chunk.as_ptr() as *const __m128i);
let q_values = _mm256_cvtph_ps(q_half);
let v0_values = _mm256_cvtph_ps(v0_half);
let v1_values = _mm256_cvtph_ps(v1_half);
let v2_values = _mm256_cvtph_ps(v2_half);
let v3_values = _mm256_cvtph_ps(v3_half);
sum_0 = _mm256_fmadd_ps(q_values, v0_values, sum_0);
sum_1 = _mm256_fmadd_ps(q_values, v1_values, sum_1);
sum_2 = _mm256_fmadd_ps(q_values, v2_values, sum_2);
sum_3 = _mm256_fmadd_ps(q_values, v3_values, sum_3);
}
[
horizontal_sum_256(sum_0),
horizontal_sum_256(sum_1),
horizontal_sum_256(sum_2),
horizontal_sum_256(sum_3),
]
}
}