use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};
use super::super::vectors::{DataRef, MinMaxIP};
use super::meta::MinMaxMeta;
use crate::bits::{Representation, Unsigned};
use crate::distances::{self, UnequalLengths};
use crate::multi_vector::distance::QueryMatRef;
use crate::multi_vector::{Chamfer, MatRef, MaxSim};
/// Stateless marker type that hosts the MinMax multi-vector kernels
/// (see `MinMaxKernel::max_sim_kernel`) as associated functions.
pub struct MinMaxKernel;
impl MinMaxKernel {
    /// Reports, for every row of `query`, the smallest [`MinMaxIP`] distance
    /// between that row and any row of `doc`.
    ///
    /// `f` is invoked exactly once per query row, in row order, as
    /// `f(row_index, smallest_distance)`. When `doc` has no rows the reported
    /// value is the fold seed, `f32::MAX`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the pairwise [`MinMaxIP`]
    /// evaluation. Rows reported to `f` before the failure stay reported; the
    /// failing row and all later rows are never passed to `f`.
    #[inline(always)]
    pub(crate) fn max_sim_kernel<const NBITS: usize, const MBITS: usize, F>(
        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
        doc: MatRef<'_, MinMaxMeta<MBITS>>,
        mut f: F,
    ) -> Result<(), UnequalLengths>
    where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
            crate::bits::BitSlice<'x, NBITS, Unsigned>,
            crate::bits::BitSlice<'y, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
        F: FnMut(usize, f32),
    {
        for (row, q_vec) in query.rows().enumerate() {
            // Running minimum over the document rows; `try_fold` stops at the
            // first failed evaluation and `?` forwards that error.
            let closest = doc.rows().try_fold(f32::MAX, |best, d_vec| {
                <MinMaxIP as PureDistanceFunction<
                    DataRef<'_, NBITS>,
                    DataRef<'_, MBITS>,
                    distances::Result<f32>,
                >>::evaluate(q_vec, d_vec)
                .map(|dist| best.min(dist))
            })?;
            f(row, closest);
        }
        Ok(())
    }
}
/// `MaxSim` over MinMax-quantized multi-vectors: one score per query row.
impl<const NBITS: usize, const MBITS: usize>
    DistanceFunctionMut<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<MBITS>>>
    for MaxSim<'_>
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    distances::InnerProduct: for<'a, 'b> PureDistanceFunction<
        crate::bits::BitSlice<'a, NBITS, Unsigned>,
        crate::bits::BitSlice<'b, MBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    /// Stores, for each query row `i`, the value produced by
    /// [`MinMaxKernel::max_sim_kernel`] via `self.set(i, ..)`.
    ///
    /// # Panics
    ///
    /// Panics when `self.size()` differs from `query.num_vectors()`.
    ///
    /// NOTE(review): an `Err` returned by the kernel, as well as the value
    /// returned by `self.set`, is intentionally discarded here, so a failed
    /// evaluation is not observable by the caller.
    #[inline(always)]
    fn evaluate(
        &mut self,
        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
        doc: MatRef<'_, MinMaxMeta<MBITS>>,
    ) {
        let slots = self.size();
        let rows = query.num_vectors();
        assert!(
            slots == rows,
            "scores buffer not right size : {} != {}",
            slots,
            rows
        );
        let _ = MinMaxKernel::max_sim_kernel(query, doc, |row, best| {
            let _ = self.set(row, best);
        });
    }
}
/// `Chamfer` over MinMax-quantized multi-vectors: a single aggregated score.
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<MBITS>>, f32>
    for Chamfer
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
        crate::bits::BitSlice<'x, NBITS, Unsigned>,
        crate::bits::BitSlice<'y, MBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    /// Returns the sum of the per-query-row values reported by
    /// [`MinMaxKernel::max_sim_kernel`], accumulated in row order starting
    /// from `0.0`.
    ///
    /// NOTE(review): an `Err` returned by the kernel is discarded; in that
    /// case the result is whatever had been accumulated before the failure.
    #[inline(always)]
    fn evaluate(
        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
        doc: MatRef<'_, MinMaxMeta<MBITS>>,
    ) -> f32 {
        let mut total = 0.0f32;
        let _ = MinMaxKernel::max_sim_kernel(query, doc, |_, row_score| {
            total += row_score;
        });
        total
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::CompressInto;
use crate::algorithms::Transform;
use crate::algorithms::transforms::NullTransform;
use crate::bits::{Representation, Unsigned};
use crate::minmax::{Data, MinMaxQuantizer};
use crate::multi_vector::{Defaulted, Mat, Standard};
use crate::num::Positive;
use diskann_utils::ReborrowMut;
use std::num::NonZeroUsize;
/// Emits a `#[test]` function named `$test_name` that calls the generic
/// function `$check::<A, B>()` once for each bit-width pair listed in the body.
macro_rules! expand_to_bitrates {
    ($test_name:ident, $check:ident) => {
        #[test]
        fn $test_name() {
            // Same width for both const parameters.
            $check::<1, 1>();
            $check::<2, 2>();
            $check::<4, 4>();
            $check::<8, 8>();
            // 8 bits for the first parameter, fewer bits for the second.
            $check::<8, 4>();
            $check::<8, 2>();
            $check::<8, 1>();
        }
    };
}
/// Shapes exercised by the tests, destructured by the consumer as
/// `(nq, nd, dim)`: query row count, document row count, vector dimension.
const TEST_CASES: &[(usize, usize, usize)] = &[
    (1, 1, 4),
    (1, 5, 8),
    (5, 1, 8),
    (3, 4, 16),
    (7, 7, 32),
    (2, 3, 128),
];
/// Builds a `MinMaxQuantizer` with a `NullTransform` of dimension `dim` and a
/// `Positive` value of `1.0` as its second constructor argument.
///
/// Panics (via `unwrap`) when `dim` is zero.
fn make_quantizer(dim: usize) -> MinMaxQuantizer {
    let nonzero_dim = NonZeroUsize::new(dim).unwrap();
    let transform = Transform::Null(NullTransform::new(nonzero_dim));
    let one = Positive::new(1.0).unwrap();
    MinMaxQuantizer::new(transform, one)
}
/// Produces a deterministic row-major `n x dim` matrix as a flat `Vec<f32>`.
///
/// The element at `(row, col)` is `((row + offset) * dim + col) as f32 * 0.1`,
/// so `offset` shifts the whole matrix by `offset` rows' worth of values.
fn generate_input_mat(n: usize, dim: usize, offset: usize) -> Vec<f32> {
    let mut values = Vec::with_capacity(n * dim);
    for row in 0..n {
        for col in 0..dim {
            values.push(((row + offset) * dim + col) as f32 * 0.1);
        }
    }
    values
}
/// Quantizes the flat `f32` buffer `input`, viewed as an `n x dim` matrix,
/// into a freshly created `Mat<MinMaxMeta<NBITS>>` using `quantizer`.
///
/// Every fallible step is unwrapped, so any construction or compression
/// failure panics.
fn compress_mat<const NBITS: usize>(
    quantizer: &MinMaxQuantizer,
    input: &[f32],
    n: usize,
    dim: usize,
) -> Mat<MinMaxMeta<NBITS>>
where
    Unsigned: Representation<NBITS>,
{
    // View the raw floats as an `n x dim` matrix.
    let layout = Standard::<f32>::new(n, dim).unwrap();
    let source = MatRef::new(layout, input).unwrap();

    // Destination with default-initialized contents.
    let mut compressed: Mat<MinMaxMeta<NBITS>> =
        Mat::new(MinMaxMeta::new(n, dim), Defaulted).unwrap();

    quantizer
        .compress_into(source, compressed.reborrow_mut())
        .unwrap();
    compressed
}
/// Reference implementation for one query vector: the smallest `MinMaxIP`
/// distance between `query` and any row of `doc`.
///
/// Returns `f32::MAX` when `doc` has no rows and panics if any pairwise
/// evaluation fails.
fn naive_max_sim_single<const NBITS: usize, const MBITS: usize>(
    query: DataRef<'_, NBITS>,
    doc: &MatRef<'_, MinMaxMeta<MBITS>>,
) -> f32
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
        crate::bits::BitSlice<'x, NBITS, Unsigned>,
        crate::bits::BitSlice<'y, MBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    doc.rows().fold(f32::MAX, |best, row| {
        let dist = <MinMaxIP as PureDistanceFunction<
            DataRef<'_, NBITS>,
            DataRef<'_, MBITS>,
            distances::Result<f32>,
        >>::evaluate(query, row)
        .unwrap();
        f32::min(best, dist)
    })
}
/// Cross-checks `MaxSim`, `MinMaxKernel::max_sim_kernel` and `Chamfer` against
/// `naive_max_sim_single` for every shape in `TEST_CASES`, with queries
/// quantized to `NBITS` bits and documents to `MBITS` bits.
fn test_matches_naive<const NBITS: usize, const MBITS: usize>()
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
        crate::bits::BitSlice<'x, NBITS, Unsigned>,
        crate::bits::BitSlice<'y, MBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    for &(nq, nd, dim) in TEST_CASES {
        let quantizer = make_quantizer(dim);

        // Documents use row offset `nq`, so their values continue where the
        // query values stop instead of repeating them.
        let raw_queries = generate_input_mat(nq, dim, 0);
        let raw_docs = generate_input_mat(nd, dim, nq);
        let packed_queries = compress_mat::<NBITS>(&quantizer, &raw_queries, nq, dim);
        let packed_docs = compress_mat::<MBITS>(&quantizer, &raw_docs, nd, dim);
        let query: QueryMatRef<_> = packed_queries.as_view().into();
        let doc = packed_docs.as_view();

        // Reference scores, one per query row.
        let reference: Vec<f32> = query
            .rows()
            .map(|q| naive_max_sim_single(q, &doc))
            .collect();

        // 1. `MaxSim` must agree with the reference within a small tolerance.
        let mut scores = vec![0.0f32; nq];
        MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
        for (i, (&exp, &got)) in reference.iter().zip(scores.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-5,
                "({NBITS},{MBITS}) ({nq},{nd},{dim}) MaxSim[{i}]: {got} != {exp}"
            );
        }

        // 2. Calling the kernel directly must reproduce `MaxSim` exactly.
        let mut from_kernel = vec![0.0f32; nq];
        MinMaxKernel::max_sim_kernel(query, doc, |row, value| from_kernel[row] = value).unwrap();
        assert_eq!(
            scores, from_kernel,
            "({NBITS},{MBITS}) ({nq},{nd},{dim}) kernel mismatch"
        );

        // 3. `Chamfer` must equal the sum of the per-row `MaxSim` scores.
        let sum: f32 = scores.iter().sum();
        let chamfer = Chamfer::evaluate(query, doc);
        assert!(
            (chamfer - sum).abs() < 1e-4,
            "({NBITS},{MBITS}) ({nq},{nd},{dim}) Chamfer {chamfer} != sum {sum}"
        );
    }
}
// Instantiate the check as the `matches_naive` test for each bit-width pair.
expand_to_bitrates!(matches_naive, test_matches_naive);
/// `MaxSim::evaluate` must panic when the score buffer length differs from
/// the number of query rows.
#[test]
#[should_panic(expected = "scores buffer not right size")]
fn max_sim_panics_on_size_mismatch() {
    let dim = 4;
    let bytes_per_row = Data::<8>::canonical_bytes(dim);

    // Zero-filled backing storage: two query rows and three document rows.
    let query_bytes = vec![0u8; 2 * bytes_per_row];
    let doc_bytes = vec![0u8; 3 * bytes_per_row];
    let query_view = MatRef::new(MinMaxMeta::<8>::new(2, dim), &query_bytes).unwrap();
    let query: QueryMatRef<_> = query_view.into();
    let doc = MatRef::new(MinMaxMeta::<8>::new(3, dim), &doc_bytes).unwrap();

    // Five score slots for two query rows: the size assertion must fire.
    let mut scores = vec![0.0f32; 5];
    MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
}
}