use std::{ops::Deref, sync::Arc};
use diskann::ANNResult;
use diskann_utils::object_pool::{self, ObjectPool, PoolOption};
use diskann_vector::PreprocessedDistanceFunction;
use super::common::get_lookup_table_size;
use crate::model::pq::fixed_chunk_pq_table::{FixedChunkPQTable, pq_dist_lookup_single};
/// Query-specific distance computer backed by a [`FixedChunkPQTable`].
///
/// An instance is built for a single query: the per-chunk lookup table is filled once by
/// `FixedChunkPQTable::populate_chunk_distances` and then read for every PQ code that is
/// evaluated against that query.
///
/// `T` is any handle that dereferences to the PQ table (the tests in this file use a
/// plain `&FixedChunkPQTable`).
// NOTE(review): the `L2` in the name suggests (squared) Euclidean distance; the exact
// metric is decided by the parent table's helpers, which are not visible here.
#[derive(Debug)]
pub struct TableL2<T: Deref<Target = FixedChunkPQTable>> {
    /// Distance lookup buffer, optionally borrowed from an object pool. Its length is
    /// chosen by `get_lookup_table_size` at construction time.
    lookup_table: PoolOption<Vec<f32>>,
    /// Value of `parent.get_num_centers()` captured at construction time.
    num_centers: usize,
    /// Handle to the PQ table that this computer was derived from.
    parent: T,
}
impl<T> TableL2<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Creates a distance computer for `query` against the PQ table behind `parent`.
    ///
    /// Each element of `query` is converted to `f32` before it is handed to the parent
    /// table. When `pool` is `Some`, the lookup buffer is created through
    /// `PoolOption::pooled`; otherwise a non-pooled buffer is created.
    ///
    /// # Errors
    ///
    /// Returns whatever error `FixedChunkPQTable::populate_chunk_distances` reports.
    ///
    /// # Panics
    ///
    /// Panics if `query` holds fewer elements than `parent.get_dim()`.
    pub(crate) fn new<U>(
        parent: T,
        query: &[U],
        pool: Option<Arc<ObjectPool<Vec<f32>>>>,
    ) -> ANNResult<Self>
    where
        U: Into<f32> + Copy,
    {
        let mut table = Self::new_unpopulated(parent, pool);
        // Hand the table back only if filling it succeeded.
        table.populate(query).map(|()| table)
    }

    /// Allocates the lookup buffer and captures `num_centers`, without filling the buffer.
    ///
    /// The buffer is requested with `object_pool::Undef`, so its contents must not be
    /// relied upon until `populate` has run.
    fn new_unpopulated(parent: T, pool: Option<Arc<ObjectPool<Vec<f32>>>>) -> Self {
        let table_len = get_lookup_table_size(&parent);
        let lookup_table = if let Some(shared) = pool {
            PoolOption::pooled(&shared, object_pool::Undef::new(table_len))
        } else {
            PoolOption::non_pooled_create(object_pool::Undef::new(table_len))
        };
        let num_centers = parent.get_num_centers();

        Self {
            lookup_table,
            num_centers,
            parent,
        }
    }

    /// Fills the lookup table for `query`.
    ///
    /// The query is copied into a temporary `Vec<f32>`, passed through
    /// `FixedChunkPQTable::preprocess_query`, and then used to populate the per-chunk
    /// distances.
    ///
    /// # Panics
    ///
    /// Panics if `query` holds fewer elements than `parent.get_dim()`.
    fn populate<U: Into<f32> + Copy>(&mut self, query: &[U]) -> ANNResult<()> {
        assert!(self.parent.get_dim() <= query.len());

        // `preprocess_query` takes the buffer mutably, so work on an owned `f32` copy
        // rather than on the caller's slice.
        let mut query_f32: Vec<f32> = query.iter().copied().map(Into::into).collect();
        self.parent.preprocess_query(&mut query_f32);
        self.parent
            .populate_chunk_distances(&query_f32, &mut self.lookup_table)
    }

    /// Looks up the distance between the stored query and the PQ-encoded vector `code`.
    ///
    /// # Panics
    ///
    /// Panics if `code.len()` differs from `parent.get_num_chunks()`.
    fn evaluate(&self, code: &[u8]) -> f32 {
        let Self {
            lookup_table,
            num_centers,
            parent,
        } = self;

        let num_chunks = parent.get_num_chunks();
        assert_eq!(
            num_chunks,
            code.len(),
            "PQ code must have {} entries",
            num_chunks
        );
        pq_dist_lookup_single(code, lookup_table, *num_centers)
    }
}
impl<T> PreprocessedDistanceFunction<&[u8], f32> for TableL2<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
self.evaluate(changing)
}
}
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use diskann_vector::Half;
    use rand::SeedableRng;
    use rstest::rstest;

    use super::{
        super::test_utils::{self, TestDistribution},
        *,
    };

    /// Small table configuration shared by the panic tests below:
    /// 10 dimensions, 3 PQ chunks and 4 pivots.
    fn panic_test_config() -> test_utils::TableConfig {
        test_utils::TableConfig {
            dim: 10,
            pq_chunks: 3,
            num_pivots: 4,
            start_value: 0.0,
        }
    }

    /// Runs `test_utils::test_l2_inner` over a grid of dimensions, chunk counts and
    /// pivot counts for each supported query element type.
    #[rstest]
    fn test_l2<T>(
        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<i8>, PhantomData::<u8>)]
        _marker: PhantomData<T>,
    ) where
        T: Into<f32> + TestDistribution,
    {
        // Fixed seed keeps the generated queries reproducible across runs.
        let mut rng = rand::rngs::StdRng::seed_from_u64(5);
        let num_trials = 10;

        for dim in [12, 17, 100, 101] {
            for pq_chunks in [1, 17, 19, 20] {
                // Combinations with more chunks than dimensions are not exercised.
                if pq_chunks > dim {
                    continue;
                }

                for num_pivots in [10, 127, 256] {
                    let config = test_utils::TableConfig {
                        dim,
                        pq_chunks,
                        num_pivots,
                        start_value: 0.0,
                    };
                    let table = test_utils::seed_pivot_table(config);

                    // Purely relative tolerance; no absolute slack is granted.
                    let errors = test_utils::RelativeAndAbsolute {
                        relative: 5e-7,
                        absolute: 0.0,
                    };

                    test_utils::test_l2_inner(
                        |table: &FixedChunkPQTable, query: &[T]| {
                            TableL2::new(table, query, None).unwrap()
                        },
                        &table,
                        num_trials,
                        config,
                        &mut rng,
                        errors,
                    );
                }
            }
        }
    }

    /// A code with four entries against a table configured with three chunks must trip
    /// the length assertion in `TableL2::evaluate`.
    #[test]
    #[should_panic(expected = "PQ code must have 3 entries")]
    fn panic_on_too_long_vector() {
        let config = panic_test_config();
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        let computer = TableL2::new(&table, &query, None).unwrap();

        let too_long = vec![0, 0, 0, 0];
        computer.evaluate_similarity(&too_long);
    }

    /// A code entry equal to the configured pivot count (4) is expected to panic with an
    /// index-out-of-bounds message during the lookup.
    #[test]
    #[should_panic(expected = "the len is 4 but the index is 4")]
    fn panic_on_out_of_bounds_entry() {
        let config = panic_test_config();
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        let computer = TableL2::new(&table, &query, None).unwrap();

        let out_of_range = vec![0, 4, 0];
        computer.evaluate_similarity(&out_of_range);
    }
}