Skip to main content

diskann_quantization/minmax/multi/
max_sim.rs

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

//! Distance implementations for MinMax quantized multi-vectors.
5
6use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};
7
8use super::super::vectors::{DataRef, MinMaxIP};
9use super::meta::MinMaxMeta;
10use crate::bits::{Representation, Unsigned};
11use crate::distances::{self, UnequalLengths};
12use crate::multi_vector::distance::QueryMatRef;
13use crate::multi_vector::{Chamfer, MatRef, MaxSim};
14
//////////////////
// MinMaxKernel //
//////////////////
18
19/// Kernel for computing [`MaxSim`] and [`Chamfer`] distance using MinMax quantized vectors.
20///
21/// Uses a simple double-iteration strategy and computes pairwise inner-products between
22/// query vectors and document vectors using [`MinMaxIP`].
23pub struct MinMaxKernel;
24
25impl MinMaxKernel {
26    /// Core kernel for computing per-query-vector max similarities using MinMax.
27    ///
28    /// For each query vector, computes the maximum similarity (min distance using
29    /// MinMax inner product) to any document vector, then calls `f(index, score)` with the result.
30    ///
31    /// # Arguments
32    ///
33    /// * `query` - The query MinMax multi-vector (wrapped as [`QueryMatRef`])
34    /// * `doc` - The document MinMax multi-vector
35    /// * `f` - Callback invoked with `(query_index, min_distance)` for each query vector
36    #[inline(always)]
37    pub(crate) fn max_sim_kernel<const NBITS: usize, const MBITS: usize, F>(
38        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
39        doc: MatRef<'_, MinMaxMeta<MBITS>>,
40        mut f: F,
41    ) -> Result<(), UnequalLengths>
42    where
43        Unsigned: Representation<NBITS> + Representation<MBITS>,
44        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
45                crate::bits::BitSlice<'x, NBITS, Unsigned>,
46                crate::bits::BitSlice<'y, MBITS, Unsigned>,
47                distances::MathematicalResult<u32>,
48            >,
49        F: FnMut(usize, f32),
50    {
51        for (i, q_ref) in query.rows().enumerate() {
52            // Find min distance (IP returns negated, so min = max similarity)
53            let mut min_distance = f32::MAX;
54
55            for d_ref in doc.rows() {
56                // Use MinMaxIP to compute negated inner product as distance
57                let dist = <MinMaxIP as PureDistanceFunction<
58                    DataRef<'_, NBITS>,
59                    DataRef<'_, MBITS>,
60                    distances::Result<f32>,
61                >>::evaluate(q_ref, d_ref)?;
62
63                min_distance = min_distance.min(dist);
64            }
65
66            f(i, min_distance);
67        }
68
69        Ok(())
70    }
71}
72
////////////
// MaxSim //
////////////
76
77impl<const NBITS: usize, const MBITS: usize>
78    DistanceFunctionMut<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<MBITS>>>
79    for MaxSim<'_>
80where
81    Unsigned: Representation<NBITS> + Representation<MBITS>,
82    distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
83            crate::bits::BitSlice<'x, NBITS, Unsigned>,
84            crate::bits::BitSlice<'y, MBITS, Unsigned>,
85            distances::MathematicalResult<u32>,
86        >,
87{
88    #[inline(always)]
89    fn evaluate(
90        &mut self,
91        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
92        doc: MatRef<'_, MinMaxMeta<MBITS>>,
93    ) {
94        assert!(
95            self.size() == query.num_vectors(),
96            "scores buffer not right size : {} != {}",
97            self.size(),
98            query.num_vectors()
99        );
100
101        let _ = MinMaxKernel::max_sim_kernel(query, doc, |i, score| {
102            // SAFETY: We asserted that self.size() == query.num_vectors(),
103            // and i < query.num_vectors() due to the kernel loop bound.
104            let _ = self.set(i, score);
105        });
106    }
107}
108
/////////////
// Chamfer //
/////////////
112
113impl<const NBITS: usize, const MBITS: usize>
114    PureDistanceFunction<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<MBITS>>, f32>
115    for Chamfer
116where
117    Unsigned: Representation<NBITS> + Representation<MBITS>,
118    distances::InnerProduct: for<'a, 'b> PureDistanceFunction<
119            crate::bits::BitSlice<'a, NBITS, Unsigned>,
120            crate::bits::BitSlice<'b, MBITS, Unsigned>,
121            distances::MathematicalResult<u32>,
122        >,
123{
124    #[inline(always)]
125    fn evaluate(
126        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
127        doc: MatRef<'_, MinMaxMeta<MBITS>>,
128    ) -> f32 {
129        let mut sum = 0.0f32;
130
131        let _ = MinMaxKernel::max_sim_kernel(query, doc, |_i, score| {
132            sum += score;
133        });
134
135        sum
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompressInto;
    use crate::algorithms::Transform;
    use crate::algorithms::transforms::NullTransform;
    use crate::bits::{Representation, Unsigned};
    use crate::minmax::{Data, MinMaxQuantizer};
    use crate::multi_vector::{Defaulted, Mat, Standard};
    use crate::num::Positive;
    use diskann_utils::ReborrowMut;
    use std::num::NonZeroUsize;

    // Emits a `#[test]` called `$name` that instantiates the generic function `$func`
    // for each supported (query bits, document bits) combination.
    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                // Query and document quantized with the same bit width.
                $func::<1, 1>();
                $func::<2, 2>();
                $func::<4, 4>();
                $func::<8, 8>();
                // 8-bit query against narrower documents.
                $func::<8, 4>();
                $func::<8, 2>();
                $func::<8, 1>();
            }
        };
    }

    /// Shapes exercised by the tests, as `(num_queries, num_docs, dim)`.
    const TEST_CASES: &[(usize, usize, usize)] = &[
        (1, 1, 4),   // one query vector, one doc vector
        (1, 5, 8),   // one query vector, several doc vectors
        (5, 1, 8),   // several query vectors, one doc vector
        (3, 4, 16),  // mixed counts
        (7, 7, 32),  // equal counts
        (2, 3, 128), // wider vectors
    ];

    /// Builds a quantizer over `dim` dimensions with a null transform and a
    /// `Positive` parameter of `1.0`.
    fn make_quantizer(dim: usize) -> MinMaxQuantizer {
        let width = NonZeroUsize::new(dim).unwrap();
        let transform = Transform::Null(NullTransform::new(width));
        MinMaxQuantizer::new(transform, Positive::new(1.0).unwrap())
    }

    /// Produces `n * dim` values in row-major order, where the element at row `i`,
    /// column `j` equals `((i + offset) * dim + j) * 0.1`.
    fn generate_input_mat(n: usize, dim: usize, offset: usize) -> Vec<f32> {
        (0..n)
            .flat_map(|i| (0..dim).map(move |j| ((i + offset) * dim + j) as f32 * 0.1))
            .collect()
    }

    /// Views `input` as an `n x dim` matrix of `f32` and compresses it with
    /// `quantizer` into a freshly created MinMax matrix.
    fn compress_mat<const NBITS: usize>(
        quantizer: &MinMaxQuantizer,
        input: &[f32],
        n: usize,
        dim: usize,
    ) -> Mat<MinMaxMeta<NBITS>>
    where
        Unsigned: Representation<NBITS>,
    {
        let layout = Standard::<f32>::new(n, dim).unwrap();
        let source = MatRef::new(layout, input).unwrap();

        let mut compressed: Mat<MinMaxMeta<NBITS>> =
            Mat::new(MinMaxMeta::new(n, dim), Defaulted).unwrap();
        quantizer
            .compress_into(source, compressed.reborrow_mut())
            .unwrap();

        compressed
    }

    /// Reference implementation for a single query vector: the smallest `MinMaxIP`
    /// distance to any row of `doc` (`f32::MAX` when `doc` has no rows).
    fn naive_max_sim_single<const NBITS: usize, const MBITS: usize>(
        query: DataRef<'_, NBITS>,
        doc: &MatRef<'_, MinMaxMeta<MBITS>>,
    ) -> f32
    where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        distances::InnerProduct: for<'q, 'd> PureDistanceFunction<
                crate::bits::BitSlice<'q, NBITS, Unsigned>,
                crate::bits::BitSlice<'d, MBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
    {
        let mut smallest = f32::MAX;
        for d in doc.rows() {
            let dist = <MinMaxIP as PureDistanceFunction<
                DataRef<'_, NBITS>,
                DataRef<'_, MBITS>,
                distances::Result<f32>,
            >>::evaluate(query, d)
            .unwrap();
            smallest = smallest.min(dist);
        }
        smallest
    }

    /// For every shape in `TEST_CASES`, checks that `MaxSim` agrees with the naive
    /// reference, that the raw kernel agrees with `MaxSim`, and that `Chamfer` equals
    /// the sum of the `MaxSim` scores.
    fn test_matches_naive<const NBITS: usize, const MBITS: usize>()
    where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        distances::InnerProduct: for<'q, 'd> PureDistanceFunction<
                crate::bits::BitSlice<'q, NBITS, Unsigned>,
                crate::bits::BitSlice<'d, MBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let quantizer = make_quantizer(dim);

            // Document rows continue the numbering where the query rows stop.
            let raw_query = generate_input_mat(nq, dim, 0);
            let raw_doc = generate_input_mat(nd, dim, nq);

            let packed_query = compress_mat::<NBITS>(&quantizer, &raw_query, nq, dim);
            let packed_doc = compress_mat::<MBITS>(&quantizer, &raw_doc, nd, dim);

            let query: QueryMatRef<_> = packed_query.as_view().into();
            let doc = packed_doc.as_view();

            // MaxSim against the naive reference.
            let expected: Vec<f32> = query
                .rows()
                .map(|q| naive_max_sim_single(q, &doc))
                .collect();

            let mut scores = vec![0.0f32; nq];
            MaxSim::new(&mut scores).unwrap().evaluate(query, doc);

            let pairs = scores.iter().copied().zip(expected.iter().copied());
            for (i, (got, exp)) in pairs.enumerate() {
                assert!(
                    (got - exp).abs() < 1e-5,
                    "({NBITS},{MBITS}) ({nq},{nd},{dim}) MaxSim[{i}]: {got} != {exp}"
                );
            }

            // Raw kernel against MaxSim.
            let mut kernel_scores = vec![0.0f32; nq];
            MinMaxKernel::max_sim_kernel(query, doc, |row, value| kernel_scores[row] = value)
                .unwrap();
            assert_eq!(
                scores, kernel_scores,
                "({NBITS},{MBITS}) ({nq},{nd},{dim}) kernel mismatch"
            );

            // Chamfer against the summed MaxSim scores.
            let chamfer = Chamfer::evaluate(query, doc);
            let sum: f32 = scores.iter().sum();
            assert!(
                (chamfer - sum).abs() < 1e-4,
                "({NBITS},{MBITS}) ({nq},{nd},{dim}) Chamfer {chamfer} != sum {sum}"
            );
        }
    }

    expand_to_bitrates!(matches_naive, test_matches_naive);

    /// A score buffer whose length differs from the number of query vectors must
    /// trigger the size assertion in `MaxSim::evaluate`.
    #[test]
    #[should_panic(expected = "scores buffer not right size")]
    fn max_sim_panics_on_size_mismatch() {
        let dim = 4;
        let bytes_per_row = Data::<8>::canonical_bytes(dim);
        let query_bytes = vec![0u8; 2 * bytes_per_row];
        let doc_bytes = vec![0u8; 3 * bytes_per_row];

        let query: QueryMatRef<_> = MatRef::new(MinMaxMeta::<8>::new(2, dim), &query_bytes)
            .unwrap()
            .into();
        let doc = MatRef::new(MinMaxMeta::<8>::new(3, dim), &doc_bytes).unwrap();

        // Five slots for two query vectors: deliberately mismatched.
        let mut scores = vec![0.0f32; 5];
        MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
    }
}