mod f16;
mod f32;
use crate::multi_vector::{BlockTransposed, MatRef, Standard};
/// Type-erased handle to a prepared query for multi-vector scoring.
///
/// The concrete backend is hidden behind a boxed [`DynQueryComputer`], so the
/// public type depends only on the element type `T`.
#[derive(Debug)]
pub struct QueryComputer<T: Copy> {
    /// Backend that performs the actual scoring and knows the query row count.
    inner: Box<dyn DynQueryComputer<T>>,
}
impl<T: Copy> QueryComputer<T> {
    /// Number of query rows, as reported by the underlying backend.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.inner.nrows()
    }

    /// Sums the per-query-row scores that [`Self::max_sim`] produces for `doc`.
    ///
    /// Returns `0.0` without allocating when `doc` contains no vectors.
    pub fn chamfer(&self, doc: MatRef<'_, Standard<T>>) -> f32 {
        let query_rows = self.nrows();
        match doc.num_vectors() {
            0 => 0.0,
            _ => {
                // One scratch slot per query row; `max_sim` fills it in.
                let mut per_query = vec![0.0f32; query_rows];
                self.max_sim(doc, &mut per_query);
                per_query.iter().sum()
            }
        }
    }

    /// Computes one score per query row against `doc`, storing them in `scores`.
    ///
    /// When `doc` contains no vectors the backend is not invoked and `scores`
    /// is left exactly as the caller provided it.
    ///
    /// # Panics
    ///
    /// Panics if `scores.len()` differs from [`Self::nrows`].
    pub fn max_sim(&self, doc: MatRef<'_, Standard<T>>, scores: &mut [f32]) {
        let expected_len = self.nrows();
        assert_eq!(
            scores.len(),
            expected_len,
            "scores buffer not right size: {} != {}",
            scores.len(),
            expected_len
        );
        // The length check above runs even for empty documents.
        if doc.num_vectors() != 0 {
            self.inner.compute_max_sim(doc, scores);
        }
    }
}
/// Object-safe backend interface stored inside [`QueryComputer`].
///
/// Implementations must be `Debug`, `Send` and `Sync` so the boxed trait
/// object can be formatted and shared across threads.
trait DynQueryComputer<T: Copy>: std::fmt::Debug + Send + Sync {
    /// Number of query rows this backend was built from.
    fn nrows(&self) -> usize;

    /// Computes per-query-row scores for `doc` into `scores`.
    ///
    /// `QueryComputer::max_sim` only calls this after checking that
    /// `scores.len() == self.nrows()` and that `doc` has at least one vector.
    // NOTE(review): whether implementations overwrite or accumulate into
    // `scores` is not visible from this file; confirm against the backends.
    fn compute_max_sim(&self, doc: MatRef<'_, Standard<T>>, scores: &mut [f32]);
}
/// Pairs an architecture value with a query representation prepared for it.
#[derive(Debug)]
struct Prepared<A, Q> {
    /// Architecture value supplied by the caller of `build_prepared`.
    arch: A,
    /// The prepared query data.
    prepared: Q,
}
/// Builds a [`Prepared`] value by converting `query` into a
/// `BlockTransposed<T, GROUP>` and bundling it with `arch`.
///
/// The conversion goes through `query.as_matrix_view()` and
/// `BlockTransposed::from_matrix_view`.
fn build_prepared<T, A, const GROUP: usize>(
    arch: A,
    query: MatRef<'_, Standard<T>>,
) -> Prepared<A, BlockTransposed<T, GROUP>>
where
    T: Copy + Default,
{
    let view = query.as_matrix_view();
    Prepared {
        arch,
        prepared: BlockTransposed::<T, GROUP>::from_matrix_view(view),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_vector::{Chamfer, MaxSim, QueryMatRef};
    use diskann_vector::distance::InnerProduct;
    use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};

    /// Test-only conversion from `f32`, so the same data generator can feed
    /// both the `f32` and the `half::f16` computers.
    trait FromF32 {
        fn from_f32(v: f32) -> Self;
    }

    impl FromF32 for f32 {
        fn from_f32(v: f32) -> Self {
            v
        }
    }

    impl FromF32 for half::f16 {
        fn from_f32(v: f32) -> Self {
            diskann_wide::cast_f32_to_f16(v)
        }
    }

    /// Wraps `data` as an `nrows x ncols` standard-layout matrix view.
    ///
    /// Panics (via `unwrap`) if the layout or the view cannot be constructed.
    fn make_mat<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> MatRef<'_, Standard<T>> {
        MatRef::new(Standard::new(nrows, ncols).unwrap(), data).unwrap()
    }

    /// Produces `len` deterministic values `((i + shift) % ceil)` converted to `T`.
    ///
    /// `ceil` must be non-zero, otherwise the modulo panics.
    fn make_test_data<T: FromF32>(len: usize, ceil: usize, shift: usize) -> Vec<T> {
        (0..len)
            .map(|v| T::from_f32(((v + shift) % ceil) as f32))
            .collect()
    }

    /// Shapes exercised by the comparison tests, as `(nq, nd, dim)`:
    /// query rows, document rows, and vector dimension.
    const TEST_CASES: &[(usize, usize, usize)] = &[
        (1, 1, 4),
        (5, 3, 5),
        (17, 4, 64),
        (16, 6, 32),
    ];

    /// Checks that `QueryComputer::chamfer` agrees with the reference
    /// `Chamfer::evaluate` to within `tol` for every shape in `TEST_CASES`.
    ///
    /// `label` is prepended to failure messages to identify the element type.
    fn check_chamfer_matches<T: Copy + FromF32>(
        build: fn(MatRef<'_, Standard<T>>) -> QueryComputer<T>,
        tol: f32,
        label: &str,
    ) where
        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let query_data = make_test_data::<T>(nq * dim, dim, dim / 2);
            let doc_data = make_test_data::<T>(nd * dim, dim, dim);
            let query = make_mat(&query_data, nq, dim);
            let doc = make_mat(&doc_data, nd, dim);

            let expected = Chamfer::evaluate(QueryMatRef::from(query), doc);
            let actual = build(query).chamfer(doc);

            assert!(
                (actual - expected).abs() < tol,
                "{label}Chamfer mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}",
            );
        }
    }

    /// Checks that `QueryComputer::max_sim` agrees element-wise with the
    /// reference `MaxSim` evaluation to within `tol` for every shape in
    /// `TEST_CASES`.
    ///
    /// `label` is prepended to failure messages to identify the element type.
    fn check_max_sim_matches<T: Copy + FromF32>(
        build: fn(MatRef<'_, Standard<T>>) -> QueryComputer<T>,
        tol: f32,
        label: &str,
    ) where
        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let query_data = make_test_data::<T>(nq * dim, dim, dim / 2);
            let doc_data = make_test_data::<T>(nd * dim, dim, dim);
            let query = make_mat(&query_data, nq, dim);
            let doc = make_mat(&doc_data, nd, dim);

            // Reference scores; the return value is not needed, only the
            // contents written into `expected_scores`.
            let mut expected_scores = vec![0.0f32; nq];
            let _ = MaxSim::new(&mut expected_scores)
                .unwrap()
                .evaluate(QueryMatRef::from(query), doc);

            let computer = build(query);
            let mut actual_scores = vec![0.0f32; nq];
            computer.max_sim(doc, &mut actual_scores);

            // Both buffers were allocated with `nq` entries, so `zip` visits
            // every query row.
            for (i, (actual, expected)) in
                actual_scores.iter().zip(expected_scores.iter()).enumerate()
            {
                assert!(
                    (actual - expected).abs() < tol,
                    "{label}MaxSim[{i}] mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}",
                );
            }
        }
    }

    #[test]
    fn query_computer_dimensions() {
        let data = vec![1.0f32; 5 * 8];
        let query = make_mat(&data, 5, 8);
        let computer = QueryComputer::<f32>::new(query);
        assert_eq!(computer.nrows(), 5);
    }

    #[test]
    fn query_computer_f16_dimensions() {
        let data = vec![diskann_wide::cast_f32_to_f16(1.0); 5 * 8];
        let query = make_mat(data.as_slice(), 5, 8);
        let computer = QueryComputer::<half::f16>::new(query);
        assert_eq!(computer.nrows(), 5);
    }

    #[test]
    fn chamfer_with_zero_docs() {
        let query = make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2);
        let computer = QueryComputer::<f32>::new(query);
        let doc = make_mat(&[], 0, 2);
        assert_eq!(computer.chamfer(doc), 0.0);
    }

    #[test]
    fn max_sim_with_zero_docs() {
        let query = make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2);
        let computer = QueryComputer::<f32>::new(query);
        let doc = make_mat::<f32>(&[], 0, 2);

        // Use distinct non-zero sentinels: a zero-filled buffer could not
        // distinguish "left untouched" from "overwritten with 0.0".
        let sentinel = [1.5f32, -2.5];
        let mut scores = sentinel.to_vec();
        computer.max_sim(doc, &mut scores);
        assert_eq!(
            scores, sentinel,
            "zero-doc MaxSim should leave scores untouched"
        );
    }

    #[test]
    #[should_panic(expected = "scores buffer not right size")]
    fn max_sim_panics_on_size_mismatch() {
        let query = make_mat(&[1.0f32, 2.0, 3.0, 4.0], 2, 2);
        let computer = QueryComputer::<f32>::new(query);
        let doc = make_mat(&[1.0, 1.0], 1, 2);
        // Three slots for a two-row query: must trip the length assertion.
        let mut scores = vec![0.0f32; 3];
        computer.max_sim(doc, &mut scores);
    }

    /// Generates a submodule with the two reference-comparison tests for one
    /// element type: `$ty` is the element type, `$tol` the absolute tolerance,
    /// and `$label` the prefix used in failure messages.
    macro_rules! test_matches_fallback {
        ($mod_name:ident, $ty:ty, $tol:expr, $label:literal) => {
            mod $mod_name {
                use super::*;

                #[test]
                fn chamfer_matches_fallback() {
                    check_chamfer_matches(QueryComputer::<$ty>::new, $tol, $label);
                }

                #[test]
                fn max_sim_matches_fallback() {
                    check_max_sim_matches(QueryComputer::<$ty>::new, $tol, $label);
                }
            }
        };
    }

    test_matches_fallback!(f32, f32, 1e-10, "f32 ");
    test_matches_fallback!(f16, half::f16, 1e-10, "f16 ");
}