#![allow(dead_code)]
use super::columnar_vectors::PDX_BLOCK_SIZE;
/// Debug-only shape checks shared by the block distance kernels.
///
/// In builds with debug assertions enabled this verifies, in order, that:
/// - the query has exactly `dimension` components,
/// - the block has exactly `PDX_BLOCK_SIZE * dimension` values, and
/// - the populated slot count does not exceed `PDX_BLOCK_SIZE`.
///
/// `PDX_BLOCK_SIZE` is resolved at the call site. A trailing comma is accepted.
macro_rules! debug_assert_pdx_inputs {
    ($q:expr, $blk:expr, $dim:expr, $live:expr $(,)?) => {{
        debug_assert_eq!($q.len(), $dim);
        debug_assert_eq!($blk.len(), PDX_BLOCK_SIZE * $dim);
        debug_assert!($live <= PDX_BLOCK_SIZE);
    }};
}
/// Squared Euclidean distance from `query` to every slot of one block.
///
/// `block` is read dimension-major: component `d` of slot `v` is taken from
/// `block[d * PDX_BLOCK_SIZE + v]`. Only the first `block_size` slots are
/// treated as populated; every remaining entry of the returned array is
/// forced to `0.0`.
///
/// Input shapes are validated with debug assertions only.
#[must_use]
pub(crate) fn block_squared_l2(
    query: &[f32],
    block: &[f32],
    dimension: usize,
    block_size: usize,
) -> [f32; PDX_BLOCK_SIZE] {
    debug_assert_pdx_inputs!(query, block, dimension, block_size);

    let mut distances = [0.0_f32; PDX_BLOCK_SIZE];
    accumulate_squared_diff(&mut distances, query, block, dimension);
    // Padding slots accumulated values from whatever the block stores there; discard them.
    zero_padding_slots(&mut distances, block_size);
    distances
}
/// Negated dot product between `query` and every slot of one block.
///
/// `block` is read dimension-major: component `d` of slot `v` is taken from
/// `block[d * PDX_BLOCK_SIZE + v]`. The first `block_size` entries of the
/// result hold `-(query · slot)`; every remaining entry is forced to `0.0`.
///
/// Input shapes are validated with debug assertions only.
#[must_use]
pub(crate) fn block_dot_product(
    query: &[f32],
    block: &[f32],
    dimension: usize,
    block_size: usize,
) -> [f32; PDX_BLOCK_SIZE] {
    debug_assert_pdx_inputs!(query, block, dimension, block_size);

    let mut scores = [0.0_f32; PDX_BLOCK_SIZE];
    accumulate_products(&mut scores, query, block, dimension);
    // Flip the sign of populated slots and clear the padding in one pass.
    negate_and_zero_padding(&mut scores, block_size);
    scores
}
/// Cosine distance (`1 - cosine similarity`) from `query` to every slot of
/// one block.
///
/// `block` is read dimension-major: component `d` of slot `v` is taken from
/// `block[d * PDX_BLOCK_SIZE + v]`. Only the first `block_size` entries of
/// the result are computed; the rest stay at `0.0`. Slots whose norm product
/// with the query is (near) zero get a distance of `1.0`.
///
/// Input shapes are validated with debug assertions only.
#[must_use]
pub(crate) fn block_cosine_distance(
    query: &[f32],
    block: &[f32],
    dimension: usize,
    block_size: usize,
) -> [f32; PDX_BLOCK_SIZE] {
    debug_assert_pdx_inputs!(query, block, dimension, block_size);

    // The query norm is shared by all slots, so it is computed once.
    let query_norm_sq = query_norm_squared(query);
    let (dots, slot_norms_sq) = accumulate_dot_and_norm(query, block, dimension);
    finalize_cosine_distances(dots, slot_norms_sq, query_norm_sq, block_size)
}
/// Adds `(query[d] - block[d * PDX_BLOCK_SIZE + v])^2` to `acc[v]` for every
/// dimension `d < dimension` and every slot `v`.
///
/// Existing contents of `acc` are kept and accumulated onto. Panics if
/// `query` or `block` is too short for `dimension`.
#[inline]
fn accumulate_squared_diff(
    acc: &mut [f32; PDX_BLOCK_SIZE],
    query: &[f32],
    block: &[f32],
    dimension: usize,
) {
    for (d, &component) in query[..dimension].iter().enumerate() {
        // Row `d` holds component `d` of every slot, stored contiguously.
        let row_start = d * PDX_BLOCK_SIZE;
        let row = &block[row_start..row_start + PDX_BLOCK_SIZE];
        for (slot, &stored) in acc.iter_mut().zip(row) {
            let delta = component - stored;
            *slot += delta * delta;
        }
    }
}
/// Adds `query[d] * block[d * PDX_BLOCK_SIZE + v]` to `acc[v]` for every
/// dimension `d < dimension` and every slot `v`.
///
/// Existing contents of `acc` are kept and accumulated onto. Panics if
/// `query` or `block` is too short for `dimension`.
#[inline]
fn accumulate_products(
    acc: &mut [f32; PDX_BLOCK_SIZE],
    query: &[f32],
    block: &[f32],
    dimension: usize,
) {
    for (d, &component) in query[..dimension].iter().enumerate() {
        // Row `d` holds component `d` of every slot, stored contiguously.
        let row_start = d * PDX_BLOCK_SIZE;
        let row = &block[row_start..row_start + PDX_BLOCK_SIZE];
        for (slot, &stored) in acc.iter_mut().zip(row) {
            *slot += component * stored;
        }
    }
}
/// Negates the first `block_size` entries of `acc` and overwrites the rest
/// with `0.0`.
///
/// A `block_size` larger than the array negates every entry and zeroes none.
#[inline]
fn negate_and_zero_padding(acc: &mut [f32; PDX_BLOCK_SIZE], block_size: usize) {
    // Clamp so an oversized `block_size` behaves like "all slots populated".
    let populated = block_size.min(acc.len());
    let (live, padding) = acc.split_at_mut(populated);
    live.iter_mut().for_each(|value| *value = -*value);
    padding.fill(0.0);
}
/// Overwrites every entry of `acc` at index `block_size` or beyond with `0.0`.
///
/// Does nothing when `block_size` is at least the array length.
#[inline]
fn zero_padding_slots(acc: &mut [f32; PDX_BLOCK_SIZE], block_size: usize) {
    // `get_mut` yields `None` for an out-of-range start, so an oversized
    // `block_size` is a no-op rather than a panic.
    if let Some(padding) = acc.get_mut(block_size..) {
        padding.fill(0.0);
    }
}
/// Computes, in a single pass over the block, the per-slot dot product with
/// `query` and the per-slot squared norm.
///
/// Returns `(dot, norm_sq)` where, for each slot `v`, `dot[v]` is the sum
/// over `d < dimension` of `query[d] * block[d * PDX_BLOCK_SIZE + v]` and
/// `norm_sq[v]` is the sum of `block[d * PDX_BLOCK_SIZE + v]^2`.
///
/// Panics if `query` or `block` is too short for `dimension`.
#[inline]
fn accumulate_dot_and_norm(
    query: &[f32],
    block: &[f32],
    dimension: usize,
) -> ([f32; PDX_BLOCK_SIZE], [f32; PDX_BLOCK_SIZE]) {
    let mut dots = [0.0_f32; PDX_BLOCK_SIZE];
    let mut norms_sq = [0.0_f32; PDX_BLOCK_SIZE];

    for (d, &component) in query[..dimension].iter().enumerate() {
        // Row `d` holds component `d` of every slot, stored contiguously.
        let row_start = d * PDX_BLOCK_SIZE;
        let row = &block[row_start..row_start + PDX_BLOCK_SIZE];
        let slots = dots.iter_mut().zip(norms_sq.iter_mut());
        for ((dot, norm_sq), &stored) in slots.zip(row) {
            *dot += component * stored;
            *norm_sq += stored * stored;
        }
    }

    (dots, norms_sq)
}
/// Returns the sum of the squares of all components of `query`
/// (the squared Euclidean norm).
#[inline]
fn query_norm_squared(query: &[f32]) -> f32 {
    query
        .iter()
        .copied()
        .map(|component| component * component)
        .sum::<f32>()
}
/// Turns per-slot dot products and squared norms into cosine distances.
///
/// For each of the first `block_size` slots the result is
/// `1 - clamp(dot / sqrt(query_norm_sq * norm_b_sq), -1, 1)`. When the
/// product of squared norms is below `f32::EPSILON^2` the slot is treated as
/// degenerate and assigned a distance of `1.0`. Entries at index
/// `block_size` and beyond stay at `0.0`.
#[inline]
fn finalize_cosine_distances(
    dot: [f32; PDX_BLOCK_SIZE],
    norm_b_sq: [f32; PDX_BLOCK_SIZE],
    query_norm_sq: f32,
    block_size: usize,
) -> [f32; PDX_BLOCK_SIZE] {
    // Below this squared-denominator value the division is considered unsafe.
    const MIN_DENOM_SQ: f32 = f32::EPSILON * f32::EPSILON;

    let mut distances = [0.0_f32; PDX_BLOCK_SIZE];
    let per_slot = distances.iter_mut().zip(dot.iter().zip(norm_b_sq.iter()));
    for (distance, (&slot_dot, &slot_norm_sq)) in per_slot.take(block_size) {
        let denom_sq = query_norm_sq * slot_norm_sq;
        *distance = if denom_sq < MIN_DENOM_SQ {
            1.0
        } else {
            // Clamp guards against rounding pushing the similarity outside [-1, 1].
            1.0 - (slot_dot / denom_sq.sqrt()).clamp(-1.0, 1.0)
        };
    }
    distances
}