use std::ops::Deref;
use diskann::ANNResult;
use diskann_vector::PreprocessedDistanceFunction;
use crate::model::pq::fixed_chunk_pq_table::FixedChunkPQTable;
/// Distance computer that compares one full-precision query against PQ-compressed codes,
/// delegating the per-code work to a [`FixedChunkPQTable`]'s `cosine_distance`.
///
/// `T` is any handle that dereferences to the table, e.g. `&FixedChunkPQTable`.
#[derive(Debug)]
pub struct DirectCosine<T: Deref<Target = FixedChunkPQTable>> {
    /// Owned copy of the query vector; its length equals the table's dimension.
    query: Vec<f32>,
    /// The PQ table (or a handle to it) used to evaluate distances.
    parent: T,
}
impl<T> DirectCosine<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Builds a computer for `query` against the PQ table reachable through `parent`.
    ///
    /// # Errors
    ///
    /// Forwards any error from populating the query buffer; the current `populate`
    /// implementation always succeeds.
    ///
    /// # Panics
    ///
    /// Panics if `query.len()` differs from `parent.get_dim()`.
    pub(crate) fn new(parent: T, query: &[f32]) -> ANNResult<Self> {
        let mut computer = Self::new_unpopulated(parent);
        computer.populate(query).map(|()| computer)
    }

    /// Creates a computer whose query buffer is zero-filled and sized to the table dimension.
    fn new_unpopulated(parent: T) -> Self {
        let dim = parent.get_dim();
        Self {
            query: vec![0.0; dim],
            parent,
        }
    }

    /// Overwrites the stored query with `query`.
    ///
    /// # Panics
    ///
    /// Panics if `query.len()` differs from the buffer length (the table dimension).
    fn populate(&mut self, query: &[f32]) -> ANNResult<()> {
        self.query.as_mut_slice().copy_from_slice(query);
        Ok(())
    }

    /// Computes the distance between the stored query and the PQ `code` via the table's
    /// `cosine_distance`.
    ///
    /// # Panics
    ///
    /// Panics if `code.len()` is not the table's number of chunks. Code entries that do not
    /// index a valid pivot are presumably rejected inside `cosine_distance` (the tests expect
    /// a panic there).
    fn evaluate(&self, code: &[u8]) -> f32 {
        let num_chunks = self.parent.get_num_chunks();
        assert_eq!(
            num_chunks,
            code.len(),
            "PQ code must have {} entries",
            num_chunks
        );
        let query = &self.query;
        self.parent.cosine_distance(query, code)
    }
}
/// Exposes [`DirectCosine`] through the generic preprocessed-distance interface: the query
/// is fixed at construction and each call supplies a PQ code as the changing argument.
impl<T> PreprocessedDistanceFunction<&[u8], f32> for DirectCosine<T>
where
    T: Deref<Target = FixedChunkPQTable>,
{
    /// Forwards to the inherent `evaluate`, including its code-length check.
    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
        Self::evaluate(self, changing)
    }
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;

    use super::{super::test_utils, *};

    /// Small table shared by the panic tests: 10 dimensions, 3 chunks, 4 pivots per chunk.
    fn small_config() -> test_utils::TableConfig {
        test_utils::TableConfig {
            dim: 10,
            pq_chunks: 3,
            num_pivots: 4,
            start_value: 0.0,
        }
    }

    /// Checks `DirectCosine` with the shared cosine harness over a grid of
    /// dimensions, chunk counts, and pivot counts.
    #[test]
    fn test_cosine() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc33529acbe474958);
        let num_trials = 20;

        // Iteration order (dim, then chunks, then pivots) fixes the RNG draw sequence.
        for dim in [64, 117, 128] {
            for pq_chunks in [2, 5, 15] {
                for num_pivots in [10, 127, 256] {
                    // Only configurations with at most one chunk per dimension are valid.
                    if pq_chunks <= dim {
                        let config = test_utils::TableConfig {
                            dim,
                            pq_chunks,
                            num_pivots,
                            start_value: 0.0,
                        };
                        let table = test_utils::seed_pivot_table(config);
                        test_utils::test_cosine_inner(
                            |pq_table: &FixedChunkPQTable, q: &[f32]| {
                                DirectCosine::new(pq_table, q).unwrap()
                            },
                            &table,
                            num_trials,
                            config,
                            &mut rng,
                            test_utils::RelativeAndAbsolute {
                                relative: 2.0e-7,
                                absolute: 0.0,
                            },
                        );
                    }
                }
            }
        }
    }

    /// A code with more entries than the table has chunks must trip the length assertion.
    #[test]
    #[should_panic(expected = "PQ code must have 3 entries")]
    fn panic_on_too_long_vector() {
        let config = small_config();
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        let computer = DirectCosine::new(&table, &query).unwrap();
        let code: [u8; 4] = [0; 4];
        computer.evaluate_similarity(&code);
    }

    /// A code entry equal to `num_pivots` (one past the last pivot) must panic.
    #[test]
    #[should_panic]
    fn panic_on_out_of_bounds_entry() {
        let config = small_config();
        let table = test_utils::seed_pivot_table(config);
        let query = vec![0.0; config.dim];
        let computer = DirectCosine::new(&table, &query).unwrap();
        let code: [u8; 3] = [0, 4, 0];
        computer.evaluate_similarity(&code);
    }
}