use core::panic;
use std::cmp::{max, min};
use super::{num_centroids, utils::get_sub_vector_centroids};
use lance_core::assume_eq;
use lance_linalg::distance::{Dot, L2, dot_distance_batch, l2::L2Prepared, l2_distance_batch};
use lance_linalg::simd::u8::u8x16;
use lance_linalg::simd::{SIMD, Shuffle};
/// Lower bound on how many leading vectors the 4-bit PQ path scores exactly (in f32)
/// before switching to the quantized SIMD path. The actual count is
/// `max(FLAT_NUM_4BIT_PQ, k_hint)`, capped at the number of vectors. The largest of
/// those exact distances becomes the quantization upper bound (`qmax`).
const FLAT_NUM_4BIT_PQ: usize = 200;
/// Builds the L2 distance table between `query` and the PQ `codebook`.
///
/// Dispatches to [`build_distance_table_l2_impl`], specialised for 4-bit or 8-bit codes.
///
/// # Panics
/// Panics if `num_bits` is neither 4 nor 8.
pub fn build_distance_table_l2<T: L2>(
    codebook: &[T],
    num_bits: u32,
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    if num_bits == 4 {
        build_distance_table_l2_impl::<4, T>(codebook, num_sub_vectors, query)
    } else if num_bits == 8 {
        build_distance_table_l2_impl::<8, T>(codebook, num_sub_vectors, query)
    } else {
        panic!("Unsupported number of bits: {}", num_bits)
    }
}
/// Builds the L2 distance table for a codebook with `2^NUM_BITS` centroids per sub-vector.
///
/// `query` is cut into slices of `query.len() / num_sub_vectors` elements. For slice
/// `i`, the `l2_distance_batch` results against the centroids returned by
/// `get_sub_vector_centroids` for sub-vector `i` are appended, so the output is laid
/// out sub-vector by sub-vector.
#[inline]
pub fn build_distance_table_l2_impl<const NUM_BITS: u32, T: L2>(
    codebook: &[T],
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    let dim = query.len();
    let sub_dim = dim / num_sub_vectors;
    let mut table = Vec::with_capacity(num_sub_vectors * 2_usize.pow(NUM_BITS));
    query
        .chunks_exact(sub_dim)
        .enumerate()
        .for_each(|(sub_idx, query_slice)| {
            let centroids =
                get_sub_vector_centroids::<NUM_BITS, _>(codebook, dim, num_sub_vectors, sub_idx);
            table.extend(l2_distance_batch(query_slice, centroids, sub_dim));
        });
    table
}
/// Builds an L2 distance table from pre-processed per-sub-vector targets.
///
/// `l2_targets[i]` scores slice `i` of `query`. Each slice is
/// `query.len() / l2_targets.len()` elements long. Row `i` of the result
/// (`num_targets` entries) is filled by `l2_targets[i].distances_into`.
/// The row width is taken from the first target. This assumes every target reports
/// the same `num_targets()` — TODO confirm.
///
/// Returns an empty table when `l2_targets` is empty.
pub fn build_distance_table_l2_prepared(l2_targets: &[L2Prepared], query: &[f32]) -> Vec<f32> {
    // Without this guard, an empty target list panics on the division by zero below.
    if l2_targets.is_empty() {
        return Vec::new();
    }
    let sub_dim = query.len() / l2_targets.len();
    let num_targets = l2_targets[0].num_targets();
    let mut result = vec![0.0f32; l2_targets.len() * num_targets];
    for (i, sub_vec) in query.chunks_exact(sub_dim).enumerate() {
        l2_targets[i].distances_into(sub_vec, &mut result[i * num_targets..][..num_targets]);
    }
    result
}
/// Builds the dot-product distance table between `query` and the PQ `codebook`.
///
/// Dispatches to [`build_distance_table_dot_impl`], specialised for 4-bit or 8-bit codes.
///
/// # Panics
/// Panics if `num_bits` is neither 4 nor 8.
pub fn build_distance_table_dot<T: Dot>(
    codebook: &[T],
    num_bits: u32,
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    match num_bits {
        4 => build_distance_table_dot_impl::<4, T>(codebook, num_sub_vectors, query),
        8 => build_distance_table_dot_impl::<8, T>(codebook, num_sub_vectors, query),
        unsupported => panic!("Unsupported number of bits: {}", unsupported),
    }
}
/// Builds the dot-product distance table for a codebook with `2^NUM_BITS` centroids
/// per sub-vector.
///
/// `query` is cut into slices of `query.len() / num_sub_vectors` elements. For slice
/// `i`, the `dot_distance_batch` results against the centroids returned by
/// `get_sub_vector_centroids` for sub-vector `i` are appended in order.
#[inline]
pub fn build_distance_table_dot_impl<const NUM_BITS: u32, T: Dot>(
    codebook: &[T],
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    let dimension = query.len();
    let sub_len = dimension / num_sub_vectors;
    let capacity = num_sub_vectors * 2_usize.pow(NUM_BITS);
    let mut table = Vec::with_capacity(capacity);
    let mut sub_idx = 0;
    for query_slice in query.chunks_exact(sub_len) {
        let centroids =
            get_sub_vector_centroids::<NUM_BITS, _>(codebook, dimension, num_sub_vectors, sub_idx);
        table.extend(dot_distance_batch(query_slice, centroids, sub_len));
        sub_idx += 1;
    }
    table
}
/// Computes the PQ distance from a query to every encoded vector.
///
/// `distance_table` holds the per-sub-vector lookup entries, 256 per sub-vector
/// on the 8-bit path. `code` is read in chunks of `num_vectors` bytes. Chunk `j`
/// holds the sub-vector-`j` code of every vector (transposed layout).
///
/// `num_bits == 4` is delegated to [`compute_pq_distance_4bit`]. Any other value
/// is treated as 8-bit. `k_hint` is only consulted by the 4-bit path.
///
/// Returns an empty vector when `code` is empty or holds fewer bytes than
/// `num_sub_vectors` (i.e. no complete vector).
#[inline]
pub(super) fn compute_pq_distance(
    distance_table: &[f32],
    num_bits: u32,
    num_sub_vectors: usize,
    code: &[u8],
    k_hint: usize,
) -> Vec<f32> {
    if code.is_empty() {
        return Vec::new();
    }
    if num_bits == 4 {
        return compute_pq_distance_4bit(distance_table, num_sub_vectors, code, k_hint);
    }
    const NUM_CENTROIDS: usize = 2_usize.pow(8);
    let num_vectors = code.len() / num_sub_vectors;
    if num_vectors == 0 {
        // `chunks_exact(0)` would panic; there is no complete vector to score.
        return Vec::new();
    }
    let mut distances = vec![0.0; num_vectors];
    for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
        let dist_table =
            &distance_table[sub_vec_idx * NUM_CENTROIDS..(sub_vec_idx + 1) * NUM_CENTROIDS];
        // Let the optimizer drop per-element bounds checks in the loop below.
        assume_eq!(dist_table.len(), NUM_CENTROIDS);
        assume_eq!(vec_indices.len(), distances.len());
        vec_indices
            .iter()
            .zip(distances.iter_mut())
            .for_each(|(&centroid_idx, sum)| {
                *sum += dist_table[centroid_idx as usize];
            });
    }
    distances
}
/// Computes PQ distances from a query to every vector encoded with 4-bit codes.
///
/// `code` is split into chunks of `num_vectors` bytes. In the `j`-th chunk, a
/// vector's byte packs the centroid index of sub-vector `2 * j` in its low nibble
/// and of sub-vector `2 * j + 1` in its high nibble. `distance_table` is read as
/// 16 consecutive entries per sub-vector.
///
/// Strategy:
/// 1. The first `flat_num` vectors (`max(FLAT_NUM_4BIT_PQ, k_hint)`, capped at
///    `num_vectors`) get exact f32 distances.
/// 2. `qmax` (largest of those exact distances) and `qmin` (smallest table entry)
///    define a `u8` quantization of the distance table.
/// 3. Every full 16-vector block is scored with SIMD shuffles over the quantized
///    table. Results for vectors at index `>= flat_num` are dequantized to f32.
/// 4. The trailing `num_vectors % 16` vectors not already covered by step 1 get
///    exact f32 distances.
///
/// # Panics
/// Panics if `num_sub_vectors` is 0 (integer division by zero) or if
/// `num_vectors` works out to 0 (`unwrap` on the max of an empty iterator).
#[inline]
pub(super) fn compute_pq_distance_4bit(
    distance_table: &[f32],
    num_sub_vectors: usize,
    code: &[u8],
    k_hint: usize,
) -> Vec<f32> {
    // Two 4-bit codes per byte, so each vector occupies `num_sub_vectors / 2` bytes.
    let num_vectors = code.len() * 2 / num_sub_vectors;
    let mut distances = vec![0.0f32; num_vectors];
    let k_hint = min(k_hint, num_vectors);
    // Exact-distance prefix: at least FLAT_NUM_4BIT_PQ and at least k_hint vectors,
    // never more than exist.
    let flat_num = max(FLAT_NUM_4BIT_PQ, k_hint).min(num_vectors);
    compute_pq_distance_4bit_flat(
        distance_table,
        num_vectors,
        code,
        0,
        flat_num,
        &mut distances,
    );
    // Quantization upper bound, taken from the exactly-scored prefix.
    let qmax = *distances
        .iter()
        .take(flat_num)
        .max_by(|a, b| a.total_cmp(b))
        .unwrap();
    let (qmin, quantized_dists_table) = quantize_distance_table(distance_table, qmax);
    const NUM_CENTROIDS: usize = 2_usize.pow(4);
    let mut quantized_dists = vec![0_u8; num_vectors];
    // Vectors that do not fill a complete 16-lane SIMD block.
    let remainder = num_vectors % NUM_CENTROIDS;
    for i in (0..num_vectors - remainder).step_by(NUM_CENTROIDS) {
        let mut block_distances = u8x16::zeros();
        for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
            // SAFETY: reads 16 bytes at offset `sub_vec_idx * 2 * 16`.
            // `quantized_dists_table` has one byte per `distance_table` entry, so this
            // relies on `distance_table` holding 16 entries per sub-vector
            // (`num_sub_vectors * 16` in total) — TODO confirm at callers.
            let origin_dist_table = unsafe {
                u8x16::load_unaligned(
                    quantized_dists_table
                        .as_ptr()
                        .add(sub_vec_idx * 2 * NUM_CENTROIDS),
                )
            };
            // SAFETY: same assumption as above, for the 16 entries of sub-vector
            // `2 * sub_vec_idx + 1`.
            let origin_next_dist_table = unsafe {
                u8x16::load_unaligned(
                    quantized_dists_table
                        .as_ptr()
                        .add((sub_vec_idx * 2 + 1) * NUM_CENTROIDS),
                )
            };
            // SAFETY: `i` steps by 16 over `0..num_vectors - remainder`, so
            // `i + 16 <= num_vectors`. `chunks_exact(num_vectors)` makes
            // `vec_indices.len() == num_vectors`.
            let indices = unsafe { u8x16::load_unaligned(vec_indices.as_ptr().add(i)) };
            // Low nibble selects the centroid of sub-vector `2 * sub_vec_idx`.
            // `shuffle` is used as a per-lane 16-entry table lookup here
            // (`out[k] = table[idx[k]]`).
            let current_indices = indices.bit_and(0x0F);
            // NOTE(review): whether `u8x16 +=` saturates or wraps is defined in
            // lance_linalg — TODO confirm. Lanes whose true distance exceeds `qmax` can
            // exceed the u8 range.
            block_distances += origin_dist_table.shuffle(current_indices);
            // High nibble selects the centroid of sub-vector `2 * sub_vec_idx + 1`.
            let next_indices = indices.right_shift::<4>();
            block_distances += origin_next_dist_table.shuffle(next_indices);
        }
        // SAFETY: `quantized_dists.len() == num_vectors` and `i + 16 <= num_vectors`
        // (same bound as the `indices` load).
        unsafe {
            block_distances.store_unaligned(quantized_dists.as_mut_ptr().add(i));
        }
    }
    if remainder > 0 {
        // Exact distances for the tail that the SIMD blocks skipped, excluding
        // anything the flat prefix already scored.
        let offset = max(num_vectors - remainder, flat_num);
        compute_pq_distance_4bit_flat(
            distance_table,
            num_vectors,
            code,
            offset,
            num_vectors - offset,
            &mut distances,
        );
    }
    // One quantization step on the f32 scale.
    // NOTE(review): each table entry had `qmin` subtracted before quantizing, and each
    // lane sums `num_sub_vectors` such entries, but `qmin` is added back only once
    // below. Unless intentional, dequantized values sit `(num_sub_vectors - 1) * qmin`
    // below the exact ones from the flat passes — confirm.
    let range = (qmax - qmin) / 255.0;
    // Only SIMD-scored vectors past the exact prefix are overwritten; the prefix and
    // the tail keep their exact f32 distances.
    distances
        .iter_mut()
        .take(num_vectors - remainder)
        .skip(flat_num)
        .zip(
            quantized_dists
                .into_iter()
                .take(num_vectors - remainder)
                .skip(flat_num),
        )
        .for_each(|(dist, q_dist)| {
            *dist = (q_dist as f32) * range + qmin;
        });
    distances
}
/// Adds exact (unquantized) 4-bit PQ distances for vectors `offset..offset + length`
/// into `dists[offset..offset + length]`.
///
/// `code` is consumed in chunks of `num_vectors` bytes. In chunk `j`, a vector's
/// byte packs the centroid index for sub-vector `2 * j` in its low nibble and for
/// sub-vector `2 * j + 1` in its high nibble. `distance_table` is indexed as 16
/// consecutive entries per sub-vector.
///
/// Values are accumulated with `+=`. Callers wanting fresh distances must zero the
/// target range of `dists` first.
///
/// # Panics
/// Panics if `num_vectors` is 0. Also panics, when `code` holds at least one full
/// chunk, if `offset + length` exceeds `num_vectors` or `dists.len()`.
fn compute_pq_distance_4bit_flat(
    distance_table: &[f32],
    num_vectors: usize,
    code: &[u8],
    offset: usize,
    length: usize,
    dists: &mut [f32],
) {
    const NUM_CENTROIDS: usize = 2_usize.pow(4);
    for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
        let vec_indices = &vec_indices[offset..offset + length];
        let distances = &mut dists[offset..offset + length];
        // Lookup tables for the two sub-vectors packed into this chunk's bytes.
        let dist_table = &distance_table[sub_vec_idx * 2 * NUM_CENTROIDS..];
        let next_dist_table = &distance_table[(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..];
        for (dist, &centroid_idx) in distances.iter_mut().zip(vec_indices) {
            let current_idx = (centroid_idx & 0x0F) as usize;
            let next_idx = (centroid_idx >> 4) as usize;
            // Two separate additions, in this order, so float rounding matches
            // accumulating one sub-vector at a time.
            *dist += dist_table[current_idx];
            *dist += next_dist_table[next_idx];
        }
    }
}
/// Quantizes `distance_table` to `u8` over the range `[qmin, qmax]`.
///
/// `qmin` is the smallest entry of the table. Each entry `d` maps to
/// `round((d - qmin) * 255 / (qmax - qmin))`. The float-to-`u8` cast saturates, so
/// entries above `qmax` become 255. If `qmax == qmin`, the scale is infinite: the
/// minimum entries become 0 (NaN cast) and everything larger becomes 255.
///
/// Returns `(qmin, quantized_table)`.
#[inline]
fn quantize_distance_table(distance_table: &[f32], qmax: f32) -> (f32, Vec<u8>) {
    let qmin = distance_table
        .iter()
        .copied()
        .fold(f32::INFINITY, f32::min);
    let scale = 255.0 / (qmax - qmin);
    let mut quantized = Vec::with_capacity(distance_table.len());
    for &value in distance_table {
        quantized.push(((value - qmin) * scale).round() as u8);
    }
    (qmin, quantized)
}
/// Reference implementation that sums distance-table lookups directly over
/// row-major (untransposed) PQ codes. The tests use it as the expected result for
/// the transposed path.
///
/// Codes are processed `V` vectors at a time. Within a vector, sub-vectors are
/// summed in groups of `C`, and each group's partial sum is added to the vector's
/// running total. Vectors left over after the last full group of `V` are summed in
/// a single pass.
#[allow(dead_code)]
fn compute_l2_distance_without_transposing<const C: usize, const V: usize>(
    distance_table: &[f32],
    num_bits: u32,
    num_sub_vectors: usize,
    code: &[u8],
) -> Vec<f32> {
    let num_centroids = num_centroids(num_bits);
    // Sum of table lookups for `codes`, where `codes[k]` belongs to sub-vector
    // `first_sub_vec + k`.
    let partial_sum = |first_sub_vec: usize, codes: &[u8]| -> f32 {
        codes
            .iter()
            .enumerate()
            .map(|(k, &centroid)| {
                distance_table[(first_sub_vec + k) * num_centroids + centroid as usize]
            })
            .sum::<f32>()
    };

    let blocks = code.chunks_exact(num_sub_vectors * V);
    let leftover = blocks.remainder();
    let mut distances = Vec::new();
    for block in blocks {
        for vec_idx in 0..V {
            let vector_codes = &block[vec_idx * num_sub_vectors..];
            let mut total = 0.0_f32;
            for start in (0..num_sub_vectors).step_by(C) {
                let width = min(C, num_sub_vectors - start);
                total += partial_sum(start, &vector_codes[start..start + width]);
            }
            distances.push(total);
        }
    }
    for vector_codes in leftover.chunks(num_sub_vectors) {
        distances.push(partial_sum(0, vector_codes));
    }
    distances
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::pq::storage::transpose;
    use arrow_array::UInt8Array;

    /// Distances computed over transposed (sub-vector-major) 8-bit codes must match
    /// a direct scan over the original row-major codes.
    #[test]
    fn test_compute_on_transposed_codes() {
        const NUM_VECTORS: usize = 100;
        const NUM_SUB_VECTORS: usize = 4;
        const NUM_BITS: u32 = 8;
        const DIMENSION: usize = 16;

        let codebook: Vec<f32> = (0..NUM_SUB_VECTORS * NUM_VECTORS * DIMENSION)
            .map(|v| v as f32)
            .collect();
        let query: Vec<f32> = (0..DIMENSION).map(|v| v as f32).collect();
        let distance_table =
            build_distance_table_l2(&codebook, NUM_BITS, NUM_SUB_VECTORS, &query);

        // Code values above 255 wrap via `as u8`.
        let row_major_codes = UInt8Array::from_iter_values(
            (0..NUM_VECTORS * NUM_SUB_VECTORS).map(|v| v as u8),
        );
        let transposed_codes = transpose(&row_major_codes, NUM_VECTORS, NUM_SUB_VECTORS);

        let actual = compute_pq_distance(
            &distance_table,
            NUM_BITS,
            NUM_SUB_VECTORS,
            transposed_codes.values(),
            NUM_VECTORS,
        );
        let expected = compute_l2_distance_without_transposing::<4, 1>(
            &distance_table,
            NUM_BITS,
            NUM_SUB_VECTORS,
            row_major_codes.values(),
        );
        assert_eq!(actual, expected);
    }
}