use std::{ops::Deref, sync::Arc};
use diskann::ANNResult;
use diskann_utils::object_pool::{self, ObjectPool, PoolOption};
use diskann_vector::PreprocessedDistanceFunction;
use super::common::get_lookup_table_size;
use crate::model::pq::fixed_chunk_pq_table::{FixedChunkPQTable, pq_dist_lookup_single};
/// Distance computer over PQ codes, backed by a per-query lookup table.
///
/// The table is filled from a query by the parent [`FixedChunkPQTable`] and is then
/// consulted by `pq_dist_lookup_single` for every code that is evaluated.
/// NOTE(review): the name suggests the distances are (squared) L2; confirm against
/// `FixedChunkPQTable::populate_chunk_distances`.
///
/// `T` is any handle that dereferences to a [`FixedChunkPQTable`] (a plain reference
/// is what the tests in this file use).
#[derive(Debug)]
pub struct TableL2<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
/// Per-query lookup buffer; either borrowed from an object pool or owned outright,
/// depending on whether a pool was supplied at construction.
lookup_table: PoolOption<Vec<f32>>,
/// Value of `parent.get_num_centers()` captured at construction time.
num_centers: usize,
/// Handle to the PQ table used to populate `lookup_table`.
parent: T,
}
impl<T> TableL2<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Creates a distance computer for `query` on top of `parent`.
    ///
    /// When `pool` is `Some`, the lookup buffer is obtained through that object pool;
    /// otherwise a fresh, non-pooled buffer is created.
    ///
    /// # Errors
    ///
    /// Returns whatever error `FixedChunkPQTable::populate_chunk_distances` reports
    /// while filling the lookup table for `query`.
    pub(crate) fn new(
        parent: T,
        query: &[f32],
        pool: Option<Arc<ObjectPool<Vec<f32>>>>,
    ) -> ANNResult<Self> {
        let mut computer = Self::new_unpopulated(parent, pool);
        // Hand the computer back only once the table has been filled successfully.
        computer.populate(query).map(|()| computer)
    }

    /// Builds the computer with a lookup buffer of the right size whose contents have
    /// not yet been populated for any query.
    fn new_unpopulated(parent: T, pool: Option<Arc<ObjectPool<Vec<f32>>>>) -> Self {
        let table_len = get_lookup_table_size(&parent);

        let lookup_table = if let Some(shared_pool) = pool {
            PoolOption::pooled(&shared_pool, object_pool::Undef::new(table_len))
        } else {
            PoolOption::non_pooled_create(object_pool::Undef::new(table_len))
        };

        // Read the centre count before `parent` is moved into the struct.
        let num_centers = parent.get_num_centers();

        Self {
            lookup_table,
            num_centers,
            parent,
        }
    }

    /// Fills the lookup table with the per-chunk distances for `query`, as computed by
    /// the parent table.
    fn populate(&mut self, query: &[f32]) -> ANNResult<()> {
        let table = &mut self.lookup_table;
        self.parent.populate_chunk_distances(query, table)
    }

    /// Looks up the distance for the PQ `code` in the populated table.
    ///
    /// # Panics
    ///
    /// Panics if `code.len()` differs from the parent's number of chunks.
    fn evaluate(&self, code: &[u8]) -> f32 {
        let num_chunks = self.parent.get_num_chunks();
        assert_eq!(
            num_chunks,
            code.len(),
            "PQ code must have {} entries",
            num_chunks
        );

        let table = &self.lookup_table;
        pq_dist_lookup_single(code, table, self.num_centers)
    }
}
impl<T> PreprocessedDistanceFunction<&[u8], f32> for TableL2<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Evaluates the preprocessed query against the PQ code `changing`.
    ///
    /// Delegates to the inherent `evaluate`, so it panics when `changing` does not
    /// contain exactly one entry per PQ chunk.
    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
        Self::evaluate(self, changing)
    }
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;

    use super::{super::test_utils, *};

    /// Seeds the small table (10 dimensions, 3 chunks, 4 pivots) and the all-zero
    /// query shared by the panic tests below.
    fn three_chunk_fixture() -> (FixedChunkPQTable, Vec<f32>) {
        let config = test_utils::TableConfig {
            dim: 10,
            pq_chunks: 3,
            num_pivots: 4,
            start_value: 0.0,
        };
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        (table, query)
    }

    /// Runs the shared L2 check over a grid of dimensions, chunk counts and pivot
    /// counts, using one seeded RNG for the whole sweep.
    #[test]
    fn test_l2() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(5);
        let num_trials = 10;

        for dim in [12, 17, 100, 101] {
            for pq_chunks in [1, 17, 19, 20] {
                // Combinations with more chunks than dimensions are skipped.
                if pq_chunks > dim {
                    continue;
                }

                for num_pivots in [10, 127, 256] {
                    let config = test_utils::TableConfig {
                        dim,
                        pq_chunks,
                        num_pivots,
                        start_value: 0.0,
                    };
                    let table = test_utils::seed_pivot_table(config);
                    let errors = test_utils::RelativeAndAbsolute {
                        relative: 5e-7,
                        absolute: 0.0,
                    };
                    let make_computer = |pq_table: &FixedChunkPQTable, query: &[f32]| {
                        TableL2::new(pq_table, query, None).unwrap()
                    };

                    test_utils::test_l2_inner(
                        make_computer,
                        &table,
                        num_trials,
                        config,
                        &mut rng,
                        errors,
                    );
                }
            }
        }
    }

    /// A code with four entries must be rejected by a three-chunk table.
    #[test]
    #[should_panic(expected = "PQ code must have 3 entries")]
    fn panic_on_too_long_vector() {
        let (table, query) = three_chunk_fixture();
        let computer = TableL2::new(&table, &query, None).unwrap();

        let too_long = vec![0, 0, 0, 0];
        computer.evaluate_similarity(&too_long);
    }

    /// A code entry of 4 (with only 4 pivots configured) must trigger an
    /// out-of-bounds panic during the lookup.
    #[test]
    #[should_panic(expected = "the len is 4 but the index is 4")]
    fn panic_on_out_of_bounds_entry() {
        let (table, query) = three_chunk_fixture();
        let computer = TableL2::new(&table, &query, None).unwrap();

        let bad_entry = vec![0, 4, 0];
        computer.evaluate_similarity(&bad_entry);
    }
}