use crate::{utils::binary_search_branchless, DataType};
/// Dot product between a dense `query` and a sparse vector given as the
/// parallel slices `v_components` (indices into `query`) and `v_values`.
///
/// Products are accumulated into `N_LANES` independent partial sums. Lane `k`
/// takes the `k`-th entry of every full chunk of `N_LANES` entries, and any
/// leftover entries go into lane 0. The lanes are added together at the end.
///
/// Only the first `min(v_components.len(), v_values.len())` entries are used,
/// mirroring `zip` semantics.
///
/// # Panics
///
/// Panics if a component id is out of bounds for `query`, or if a `to_f32()`
/// conversion fails (its result is `unwrap`ped).
#[inline]
#[must_use]
pub fn dot_product_dense_sparse<Q, V>(query: &[Q], v_components: &[u16], v_values: &[V]) -> f32
where
    Q: DataType,
    V: DataType,
{
    const N_LANES: usize = 4;

    // Restrict both inputs to their common prefix so the full chunks and the
    // leftover tail stay aligned even when the two slices differ in length.
    let len = v_components.len().min(v_values.len());
    let (components, values) = (&v_components[..len], &v_values[..len]);

    let component_chunks = components.chunks_exact(N_LANES);
    let value_chunks = values.chunks_exact(N_LANES);
    // `remainder` borrows the underlying slice (not the iterator), so it can be
    // taken before the chunk iterators are consumed by the loop below.
    let component_tail = component_chunks.remainder();
    let value_tail = value_chunks.remainder();

    let mut result = [0.0_f32; N_LANES];
    for (ids, vals) in component_chunks.zip(value_chunks) {
        // Lane `k` accumulates the `k`-th entry of each chunk.
        for ((acc, &id), val) in result.iter_mut().zip(ids).zip(vals) {
            *acc += query[id as usize].to_f32().unwrap() * val.to_f32().unwrap();
        }
    }
    // Fewer than `N_LANES` entries remain; they all go into lane 0.
    for (&id, val) in component_tail.iter().zip(value_tail) {
        result[0] += query[id as usize].to_f32().unwrap() * val.to_f32().unwrap();
    }

    result.iter().sum()
}
/// Dot product between a sparse query and a sparse vector, locating each query
/// term in `v_terms_ids` with [`binary_search_branchless`].
///
/// Each query term is looked up independently. A term contributes only if the
/// position returned by the search holds that same term id and a value exists
/// at that position. Terms not found in the vector contribute nothing.
///
/// NOTE(review): `v_terms_ids` presumably must be sorted ascending for the
/// search to find matches; confirm against `binary_search_branchless`.
///
/// # Panics
///
/// Panics if a matched `to_f32()` conversion fails (its result is `unwrap`ped).
#[inline]
#[must_use]
pub fn dot_product_with_binary_search<Q, V>(
    query_term_ids: &[u16],
    query_values: &[Q],
    v_terms_ids: &[u16],
    v_values: &[V],
) -> f32
where
    Q: DataType,
    V: DataType,
{
    // Nothing can match an empty vector. Returning early also avoids handing an
    // empty slice to the search helper.
    if v_terms_ids.is_empty() {
        return 0.0;
    }

    let mut result = 0.0;
    for (&term_id, &value) in query_term_ids.iter().zip(query_values) {
        let i = binary_search_branchless(v_terms_ids, term_id);
        // Use checked lookups because this function does not enforce that `i` is in
        // bounds, nor that `v_values` is as long as `v_terms_ids`. An
        // out-of-range index counts as "not found".
        let hit = v_terms_ids
            .get(i)
            .filter(|&&id| id == term_id)
            .and(v_values.get(i));
        if let Some(v) = hit {
            result += value.to_f32().unwrap() * v.to_f32().unwrap();
        }
    }
    result
}
/// Dot product between two sparse vectors by merging their term-id lists.
///
/// A single cursor moves forward through the vector's `(id, value)` entries as
/// the query terms are visited in order. Both id lists must therefore be sorted
/// ascending for every shared term to be found. The walk stops once the
/// vector's entries are exhausted.
///
/// Only the first `min(v_term_ids.len(), v_values.len())` vector entries are
/// considered.
///
/// # Panics
///
/// Panics if a matched `to_f32()` conversion fails (its result is `unwrap`ped).
#[inline]
#[must_use]
pub fn dot_product_with_merge<Q, V>(
    query_term_ids: &[u16],
    query_values: &[Q],
    v_term_ids: &[u16],
    v_values: &[V],
) -> f32
where
    Q: DataType,
    V: DataType,
{
    let mut result = 0.0;
    // Pairing ids with values up front bounds the walk by the shorter slice,
    // so no index can run past the end of `v_values`.
    let mut v_entries = v_term_ids.iter().zip(v_values).peekable();
    for (&q_id, q_v) in query_term_ids.iter().zip(query_values) {
        // Skip vector terms that sort before the current query term.
        while v_entries.next_if(|&(&id, _)| id < q_id).is_some() {}
        match v_entries.peek() {
            // Vector exhausted: no later query term can match.
            None => break,
            // The cursor stays put on a match, so a repeated query id matches
            // the same vector entry again.
            Some(&(&id, v)) if id == q_id => {
                result += v.to_f32().unwrap() * q_v.to_f32().unwrap();
            }
            Some(_) => {}
        }
    }
    result
}