Skip to main content

diskann_quantization/multi_vector/distance/query_computer/
mod.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! Architecture-opaque query computer with runtime dispatch.
5//!
6//! [`QueryComputer`] wraps a block-transposed query and a captured
7//! architecture token behind a trait-object vtable. CPU detection happens
8//! once at construction; every subsequent distance call goes through
9//! [`Architecture::run3`](diskann_wide::Architecture::run3) with full
10//! `#[target_feature]` propagation — no re-dispatch and no enum matching
11//! on the hot path.
12//!
13//! # Usage
14//!
15//! ```
16//! use diskann_quantization::multi_vector::{
17//!     QueryComputer, MatRef, Standard,
18//! };
19//!
20//! let query_data = [1.0f32, 0.0, 0.0, 1.0];
21//! let doc_data = [1.0f32, 0.0, 0.0, 1.0];
22//!
23//! let query = MatRef::new(Standard::new(2, 2).unwrap(), &query_data).unwrap();
24//! let doc = MatRef::new(Standard::new(2, 2).unwrap(), &doc_data).unwrap();
25//!
26//! // Build — runtime detects arch, picks optimal GROUP, captures both
27//! let computer = QueryComputer::<f32>::new(query);
28//!
29//! // Distance — vtable → arch.run3 with target_feature propagation
30//! let dist = computer.chamfer(doc);
31//! assert_eq!(dist, -2.0);
32//! ```
33
34mod f16;
35mod f32;
36
37use crate::multi_vector::{BlockTransposed, MatRef, Standard};
38
39/// Architecture-dispatched query computer for multi-vector distance.
40#[derive(Debug)]
41pub struct QueryComputer<T: Copy> {
42    inner: Box<dyn DynQueryComputer<T>>,
43}
44
45impl<T: Copy> QueryComputer<T> {
46    /// Number of logical (non-padded) query vectors.
47    #[inline]
48    pub fn nrows(&self) -> usize {
49        self.inner.nrows()
50    }
51
52    /// Compute Chamfer distance (sum of per-query max similarities, negated).
53    ///
54    /// Returns `0.0` if the document has zero vectors.
55    pub fn chamfer(&self, doc: MatRef<'_, Standard<T>>) -> f32 {
56        let nq = self.nrows();
57        if doc.num_vectors() == 0 {
58            return 0.0;
59        }
60        let mut scores = vec![0.0f32; nq];
61        self.max_sim(doc, &mut scores);
62        scores.iter().sum()
63    }
64
65    /// Compute per-query-vector max similarities into `scores`.
66    ///
67    /// `scores` must have length equal to [`nrows()`](Self::nrows).
68    /// Each entry is the negated max inner product for that query vector.
69    ///
70    /// # Panics
71    ///
72    /// Panics if `scores.len() != self.nrows()`.
73    pub fn max_sim(&self, doc: MatRef<'_, Standard<T>>, scores: &mut [f32]) {
74        let nq = self.nrows();
75        assert_eq!(
76            scores.len(),
77            nq,
78            "scores buffer not right size: {} != {}",
79            scores.len(),
80            nq
81        );
82
83        if doc.num_vectors() == 0 {
84            return;
85        }
86
87        self.inner.compute_max_sim(doc, scores);
88    }
89}
90
91trait DynQueryComputer<T: Copy>: std::fmt::Debug + Send + Sync {
92    fn compute_max_sim(&self, doc: MatRef<'_, Standard<T>>, scores: &mut [f32]);
93    fn nrows(&self) -> usize;
94}
95
96#[derive(Debug)]
97struct Prepared<A, Q> {
98    arch: A,
99    prepared: Q,
100}
101
102fn build_prepared<T: Copy + Default, A, const GROUP: usize>(
103    arch: A,
104    query: MatRef<'_, Standard<T>>,
105) -> Prepared<A, BlockTransposed<T, GROUP>> {
106    let prepared = BlockTransposed::<T, GROUP>::from_matrix_view(query.as_matrix_view());
107    Prepared { arch, prepared }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_vector::{Chamfer, MaxSim, QueryMatRef};
    use diskann_vector::distance::InnerProduct;
    use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};

    /// Conversion from `f32` so the same fixtures can be generated for every
    /// element type exercised by these tests.
    trait FromF32 {
        fn from_f32(value: f32) -> Self;
    }

    impl FromF32 for f32 {
        fn from_f32(value: f32) -> Self {
            value
        }
    }

    impl FromF32 for half::f16 {
        fn from_f32(value: f32) -> Self {
            diskann_wide::cast_f32_to_f16(value)
        }
    }

    /// Wraps `data` as an `nrows x ncols` row-major matrix view, panicking if
    /// the layout or the slice length is rejected.
    fn make_mat<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> MatRef<'_, Standard<T>> {
        let layout = Standard::new(nrows, ncols).unwrap();
        MatRef::new(layout, data).unwrap()
    }

    /// Produces `len` values cycling through `0..ceil`, starting at
    /// `shift % ceil`.
    fn make_test_data<T: FromF32>(len: usize, ceil: usize, shift: usize) -> Vec<T> {
        (shift..shift + len)
            .map(|raw| T::from_f32((raw % ceil) as f32))
            .collect()
    }

    /// Query and document buffers shared by the agreement checks below.
    fn fixture_data<T: FromF32>(nq: usize, nd: usize, dim: usize) -> (Vec<T>, Vec<T>) {
        let query_data = make_test_data::<T>(nq * dim, dim, dim / 2);
        let doc_data = make_test_data::<T>(nd * dim, dim, dim);
        (query_data, doc_data)
    }

    /// Shapes for the `chamfer_matches_fallback` / `max_sim_matches_fallback`
    /// agreement checks: (num_queries, num_docs, dim).
    ///
    /// This matrix targets the API-layer wiring that lives above the
    /// kernel — `QueryComputer::new` query setup, `chamfer` row
    /// summation, `max_sim` per-row writeback, and the f16 query
    /// conversion path — not kernel correctness. A small
    /// representative set is sufficient because exhaustive shape
    /// coverage (panel boundaries, B-remainder classes, prime `k`,
    /// degenerate dims) is pinned one layer below in
    /// `kernels::tiled_reduce::tests::NAIVE_CASES`, and structural
    /// loop-path coverage in `tiled_reduce_all_loop_paths_match_naive`.
    const TEST_CASES: &[(usize, usize, usize)] = &[
        (1, 1, 4), // Degenerate
        (5, 3, 5), // Prime k; nq > 1 and nd > 1 exercise chamfer summation
        //              and per-row max_sim writeback on a non-trivial shape
        (17, 4, 64), // A-panel remainder crossing both Scalar and V3 panel widths
        (16, 6, 32), // B-remainder ≠ 1 (V3 b_remainder = 2)
    ];

    /// Asserts that `QueryComputer::chamfer` agrees with the reference
    /// `Chamfer::evaluate` on every shape in `TEST_CASES`.
    fn check_chamfer_matches<T: Copy + FromF32>(
        build: fn(MatRef<'_, Standard<T>>) -> QueryComputer<T>,
        tol: f32,
        label: &str,
    ) where
        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let (query_data, doc_data) = fixture_data::<T>(nq, nd, dim);
            let query = make_mat(&query_data, nq, dim);
            let doc = make_mat(&doc_data, nd, dim);

            let expected = Chamfer::evaluate(QueryMatRef::from(query), doc);
            let actual = build(query).chamfer(doc);

            let error = (actual - expected).abs();
            assert!(
                error < tol,
                "{label}Chamfer mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}",
            );
        }
    }

    /// Asserts that `QueryComputer::max_sim` agrees row-by-row with the
    /// reference `MaxSim` evaluation on every shape in `TEST_CASES`.
    fn check_max_sim_matches<T: Copy + FromF32>(
        build: fn(MatRef<'_, Standard<T>>) -> QueryComputer<T>,
        tol: f32,
        label: &str,
    ) where
        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let (query_data, doc_data) = fixture_data::<T>(nq, nd, dim);
            let query = make_mat(&query_data, nq, dim);
            let doc = make_mat(&doc_data, nd, dim);

            let mut expected_scores = vec![0.0f32; nq];
            let _ = MaxSim::new(&mut expected_scores)
                .unwrap()
                .evaluate(QueryMatRef::from(query), doc);

            let mut actual_scores = vec![0.0f32; nq];
            build(query).max_sim(doc, &mut actual_scores);

            let paired = actual_scores.iter().zip(expected_scores.iter());
            for (i, (&actual, &expected)) in paired.enumerate() {
                assert!(
                    (actual - expected).abs() < tol,
                    "{label}MaxSim[{i}] mismatch for ({nq},{nd},{dim}): actual={}, expected={}",
                    actual,
                    expected,
                );
            }
        }
    }

    #[test]
    fn query_computer_dimensions() {
        let values = vec![1.0f32; 5 * 8];
        let computer = QueryComputer::<f32>::new(make_mat(&values, 5, 8));

        assert_eq!(computer.nrows(), 5);
    }

    #[test]
    fn query_computer_f16_dimensions() {
        let values = vec![diskann_wide::cast_f32_to_f16(1.0); 5 * 8];
        let computer = QueryComputer::<half::f16>::new(make_mat(values.as_slice(), 5, 8));

        assert_eq!(computer.nrows(), 5);
    }

    #[test]
    fn chamfer_with_zero_docs() {
        let computer = QueryComputer::<f32>::new(make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2));
        let empty_doc = make_mat(&[], 0, 2);
        assert_eq!(computer.chamfer(empty_doc), 0.0);
    }

    #[test]
    fn max_sim_with_zero_docs() {
        let computer = QueryComputer::<f32>::new(make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2));
        let empty_doc = make_mat::<f32>(&[], 0, 2);
        let mut scores = vec![0.0f32; 2];
        computer.max_sim(empty_doc, &mut scores);
        // With zero docs the scores buffer is left untouched.
        for score in scores {
            assert_eq!(score, 0.0, "zero-doc MaxSim should leave scores untouched");
        }
    }

    #[test]
    #[should_panic(expected = "scores buffer not right size")]
    fn max_sim_panics_on_size_mismatch() {
        let computer = QueryComputer::<f32>::new(make_mat(&[1.0f32, 2.0, 3.0, 4.0], 2, 2));
        let doc = make_mat(&[1.0, 1.0], 1, 2);
        // Two query rows but three score slots: the length check must fire.
        let mut wrong_sized = vec![0.0f32; 3];
        computer.max_sim(doc, &mut wrong_sized);
    }

    /// Generates a submodule with the chamfer / max-sim agreement tests for
    /// one element type.
    macro_rules! test_matches_fallback {
        ($module:ident, $elem:ty, $tol:expr, $label:literal) => {
            mod $module {
                use super::*;

                #[test]
                fn chamfer_matches_fallback() {
                    check_chamfer_matches(QueryComputer::<$elem>::new, $tol, $label);
                }

                #[test]
                fn max_sim_matches_fallback() {
                    check_max_sim_matches(QueryComputer::<$elem>::new, $tol, $label);
                }
            }
        };
    }

    test_matches_fallback!(f32, f32, 1e-10, "f32 ");
    test_matches_fallback!(f16, half::f16, 1e-10, "f16 ");
}