use std::ops::Deref;
use diskann::ANNResult;
use diskann_vector::PreprocessedDistanceFunction;
use crate::model::pq::fixed_chunk_pq_table::FixedChunkPQTable;
/// Distance computer that scores PQ codes against a stored query by delegating to
/// the parent table's `cosine_distance` routine.
///
/// `T` is any handle that dereferences to a [`FixedChunkPQTable`] (for example a plain
/// reference, as used by the tests in this file).
#[derive(Debug)]
pub struct DirectCosine<T>
where
T: Deref<Target = FixedChunkPQTable>,
{
// Query converted to `f32`; allocated with `parent.get_dim()` entries.
query: Vec<f32>,
// Handle to the PQ table that supplies the dimensions and performs the distance computation.
parent: T,
}
impl<T> DirectCosine<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Create a distance computer bound to `parent` and load `query` into it.
    ///
    /// The first `parent.get_dim()` elements of `query` are converted to `f32` and
    /// stored; any trailing elements are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `query` holds fewer than `parent.get_dim()` elements.
    pub(crate) fn new<U>(parent: T, query: &[U]) -> ANNResult<Self>
    where
        U: Into<f32> + Copy,
    {
        let mut computer = Self::new_unpopulated(parent);
        // Only hand the computer out once the query has been loaded successfully.
        computer.populate(query).map(|()| computer)
    }

    /// Allocate a computer whose query buffer is zero-filled and sized to the
    /// dimension reported by `parent`.
    fn new_unpopulated(parent: T) -> Self {
        // Read the dimension before `parent` is moved into the struct.
        let dim = parent.get_dim();
        Self {
            query: vec![0.0f32; dim],
            parent,
        }
    }

    /// Overwrite the stored query with the leading elements of `query`, converting
    /// each one to `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `query` is shorter than the internal buffer.
    fn populate<U>(&mut self, query: &[U]) -> ANNResult<()>
    where
        U: Into<f32> + Copy,
    {
        // A longer `query` is accepted; the zip below stops at the buffer length.
        assert!(self.query.len() <= query.len());
        for (slot, value) in self.query.iter_mut().zip(query) {
            *slot = (*value).into();
        }
        Ok(())
    }

    /// Compute the distance between the stored query and the PQ-encoded vector `code`
    /// via the parent table's `cosine_distance`.
    ///
    /// # Panics
    ///
    /// Panics if `code.len()` differs from `parent.get_num_chunks()`.
    fn evaluate(&self, code: &[u8]) -> f32 {
        let num_chunks = self.parent.get_num_chunks();
        assert_eq!(
            num_chunks,
            code.len(),
            "PQ code must have {} entries",
            num_chunks
        );
        self.parent.cosine_distance(&self.query, code)
    }
}
impl<T> PreprocessedDistanceFunction<&[u8], f32> for DirectCosine<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Score the PQ code `changing` against the query captured at construction time.
    ///
    /// Forwards to the inherent `evaluate` method, so it panics when `changing` does
    /// not contain exactly `get_num_chunks()` entries.
    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
        Self::evaluate(self, changing)
    }
}
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use diskann_vector::Half;
    use rand::SeedableRng;
    use rstest::rstest;

    use super::{
        super::test_utils::{self, TestDistribution},
        *,
    };

    /// Table layout shared by the panic tests: 10 dimensions, 3 chunks, 4 pivots.
    fn small_config() -> test_utils::TableConfig {
        test_utils::TableConfig {
            dim: 10,
            pq_chunks: 3,
            num_pivots: 4,
            start_value: 0.0,
        }
    }

    /// Run the shared cosine checker over a grid of table shapes for every
    /// supported query element type.
    #[rstest]
    #[case(PhantomData::<f32>)]
    #[case(PhantomData::<Half>)]
    #[case(PhantomData::<i8>)]
    #[case(PhantomData::<u8>)]
    fn test_cosine<T>(#[case] _marker: PhantomData<T>)
    where
        T: Into<f32> + TestDistribution,
    {
        // Fixed seed so that every run draws the same inputs.
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc33529acbe474958);
        let num_trials = 20;

        for dim in [64, 117, 128] {
            for pq_chunks in [2, 5, 15] {
                // A table cannot have more chunks than dimensions; skip such shapes.
                if pq_chunks > dim {
                    continue;
                }

                for num_pivots in [10, 127, 256] {
                    let config = test_utils::TableConfig {
                        dim,
                        pq_chunks,
                        num_pivots,
                        start_value: 0.0,
                    };
                    let table = test_utils::seed_pivot_table(config);

                    // Tolerances handed to the shared checker.
                    let errors = test_utils::RelativeAndAbsolute {
                        relative: 2.0e-7,
                        absolute: 0.0,
                    };

                    test_utils::test_cosine_inner(
                        |pq_table: &FixedChunkPQTable, query: &[T]| {
                            DirectCosine::new(pq_table, query).unwrap()
                        },
                        &table,
                        num_trials,
                        config,
                        &mut rng,
                        errors,
                    );
                }
            }
        }
    }

    /// A code with more entries than the table has chunks must be rejected.
    #[test]
    #[should_panic(expected = "PQ code must have 3 entries")]
    fn panic_on_too_long_vector() {
        let config = small_config();
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        let computer = DirectCosine::new(&table, &query).unwrap();

        // Four entries for a three-chunk table.
        let code = vec![0, 0, 0, 0];
        computer.evaluate_similarity(&code);
    }

    /// A code entry equal to `num_pivots` (4) is expected to trigger a panic.
    #[test]
    #[should_panic]
    fn panic_on_out_of_bounds_entry() {
        let config = small_config();
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        let computer = DirectCosine::new(&table, &query).unwrap();

        // Correct length, but the middle entry equals the pivot count.
        let code = vec![0, 4, 0];
        computer.evaluate_similarity(&code);
    }
}