Skip to main content

diskann_quantization/minmax/multi/max_sim.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! Distance implementations for MinMax quantized multi-vectors.
5
6use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};
7
8use super::super::vectors::{DataRef, MinMaxIP};
9use super::meta::MinMaxMeta;
10use crate::bits::{Representation, Unsigned};
11use crate::distances::{self, UnequalLengths};
12use crate::multi_vector::distance::QueryMatRef;
13use crate::multi_vector::{Chamfer, MatRef, MaxSim};
14
15//////////////////
16// MinMaxKernel //
17//////////////////
18
19/// Kernel for computing [`MaxSim`] and [`Chamfer`] distance using MinMax quantized vectors.
20///
21/// Uses a simple double-iteration strategy and computes pairwise inner-products between
22/// query vectors and document vectors using [`MinMaxIP`].
23pub struct MinMaxKernel;
24
25impl MinMaxKernel {
26    /// Core kernel for computing per-query-vector max similarities using MinMax.
27    ///
28    /// For each query vector, computes the maximum similarity (min distance using
29    /// MinMax inner product) to any document vector, then calls `f(index, score)` with the result.
30    ///
31    /// # Arguments
32    ///
33    /// * `query` - The query MinMax multi-vector (wrapped as [`QueryMatRef`])
34    /// * `doc` - The document MinMax multi-vector
35    /// * `f` - Callback invoked with `(query_index, min_distance)` for each query vector
36    #[inline(always)]
37    pub(crate) fn max_sim_kernel<const NBITS: usize, F>(
38        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
39        doc: MatRef<'_, MinMaxMeta<NBITS>>,
40        mut f: F,
41    ) -> Result<(), UnequalLengths>
42    where
43        Unsigned: Representation<NBITS>,
44        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
45                crate::bits::BitSlice<'x, NBITS, Unsigned>,
46                crate::bits::BitSlice<'y, NBITS, Unsigned>,
47                distances::MathematicalResult<u32>,
48            >,
49        F: FnMut(usize, f32),
50    {
51        for (i, q_ref) in query.rows().enumerate() {
52            // Find min distance (IP returns negated, so min = max similarity)
53            let mut min_distance = f32::MAX;
54
55            for d_ref in doc.rows() {
56                // Use MinMaxIP to compute negated inner product as distance
57                let dist = <MinMaxIP as PureDistanceFunction<
58                    DataRef<'_, NBITS>,
59                    DataRef<'_, NBITS>,
60                    distances::Result<f32>,
61                >>::evaluate(q_ref, d_ref)?;
62
63                min_distance = min_distance.min(dist);
64            }
65
66            f(i, min_distance);
67        }
68
69        Ok(())
70    }
71}
72
73////////////
74// MaxSim //
75////////////
76
77impl<const NBITS: usize>
78    DistanceFunctionMut<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<NBITS>>>
79    for MaxSim<'_>
80where
81    Unsigned: Representation<NBITS>,
82    distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
83            crate::bits::BitSlice<'x, NBITS, Unsigned>,
84            crate::bits::BitSlice<'y, NBITS, Unsigned>,
85            distances::MathematicalResult<u32>,
86        >,
87{
88    #[inline(always)]
89    fn evaluate(
90        &mut self,
91        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
92        doc: MatRef<'_, MinMaxMeta<NBITS>>,
93    ) {
94        assert!(
95            self.size() == query.num_vectors(),
96            "scores buffer not right size : {} != {}",
97            self.size(),
98            query.num_vectors()
99        );
100
101        let _ = MinMaxKernel::max_sim_kernel(query, doc, |i, score| {
102            // SAFETY: We asserted that self.size() == query.num_vectors(),
103            // and i < query.num_vectors() due to the kernel loop bound.
104            let _ = self.set(i, score);
105        });
106    }
107}
108
109/////////////
110// Chamfer //
111/////////////
112
113impl<const NBITS: usize>
114    PureDistanceFunction<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<NBITS>>, f32>
115    for Chamfer
116where
117    Unsigned: Representation<NBITS>,
118    distances::InnerProduct: for<'a, 'b> PureDistanceFunction<
119            crate::bits::BitSlice<'a, NBITS, Unsigned>,
120            crate::bits::BitSlice<'b, NBITS, Unsigned>,
121            distances::MathematicalResult<u32>,
122        >,
123{
124    #[inline(always)]
125    fn evaluate(
126        query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
127        doc: MatRef<'_, MinMaxMeta<NBITS>>,
128    ) -> f32 {
129        let mut sum = 0.0f32;
130
131        let _ = MinMaxKernel::max_sim_kernel(query, doc, |_i, score| {
132            sum += score;
133        });
134
135        sum
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompressInto;
    use crate::algorithms::Transform;
    use crate::algorithms::transforms::NullTransform;
    use crate::bits::{Representation, Unsigned};
    use crate::minmax::{Data, MinMaxQuantizer};
    use crate::multi_vector::{Defaulted, Mat, Standard};
    use crate::num::Positive;
    use diskann_utils::ReborrowMut;
    use std::num::NonZeroUsize;

    /// Expands to one `#[test]` named `$name` that calls `$func` at 1, 2, 4 and 8 bits.
    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }

    /// Shapes exercised by the comparison test, as `(num_queries, num_docs, dim)`.
    const TEST_CASES: &[(usize, usize, usize)] = &[
        (1, 1, 4),   // one query vs. one document vector
        (1, 5, 8),   // one query vs. several document vectors
        (5, 1, 8),   // several queries vs. one document vector
        (3, 4, 16),  // general shape
        (7, 7, 32),  // equal query and document counts
        (2, 3, 128), // wider vectors
    ];

    /// Builds a [`MinMaxQuantizer`] over `dim` dimensions with a [`NullTransform`].
    fn make_quantizer(dim: usize) -> MinMaxQuantizer {
        let dim = NonZeroUsize::new(dim).unwrap();
        let transform = Transform::Null(NullTransform::new(dim));
        let one = Positive::new(1.0).unwrap();
        MinMaxQuantizer::new(transform, one)
    }

    /// Produces `n` row-major rows of `dim` values. Row `r` holds `0.1 * ((r + offset) * dim + c)`
    /// in column `c`.
    fn generate_input_mat(n: usize, dim: usize, offset: usize) -> Vec<f32> {
        // `(r + offset) * dim + c` equals `offset * dim + flat_index`, so the matrix is a
        // contiguous run of integers, each scaled by 0.1.
        let first = offset * dim;
        (first..first + n * dim).map(|v| v as f32 * 0.1).collect()
    }

    /// Quantizes an `n x dim` row-major `f32` matrix into a MinMax multi-vector.
    fn compress_mat<const NBITS: usize>(
        quantizer: &MinMaxQuantizer,
        input: &[f32],
        n: usize,
        dim: usize,
    ) -> Mat<MinMaxMeta<NBITS>>
    where
        Unsigned: Representation<NBITS>,
    {
        let layout = Standard::<f32>::new(n, dim).unwrap();
        let source = MatRef::new(layout, input).unwrap();

        let mut compressed: Mat<MinMaxMeta<NBITS>> =
            Mat::new(MinMaxMeta::new(n, dim), Defaulted).unwrap();
        quantizer
            .compress_into(source, compressed.reborrow_mut())
            .unwrap();

        compressed
    }

    /// Brute-force reference: the smallest [`MinMaxIP`] distance from `query` to any row of `doc`.
    fn naive_max_sim_single<const NBITS: usize>(
        query: DataRef<'_, NBITS>,
        doc: &MatRef<'_, MinMaxMeta<NBITS>>,
    ) -> f32
    where
        Unsigned: Representation<NBITS>,
        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
                crate::bits::BitSlice<'x, NBITS, Unsigned>,
                crate::bits::BitSlice<'y, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
    {
        let mut best = f32::MAX;
        for doc_row in doc.rows() {
            let distance = <MinMaxIP as PureDistanceFunction<
                DataRef<'_, NBITS>,
                DataRef<'_, NBITS>,
                distances::Result<f32>,
            >>::evaluate(query, doc_row)
            .unwrap();
            best = f32::min(best, distance);
        }
        best
    }

    /// Checks MaxSim against the brute-force reference, checks the raw kernel against MaxSim,
    /// and checks that Chamfer equals the sum of the MaxSim scores.
    fn test_matches_naive<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
                crate::bits::BitSlice<'x, NBITS, Unsigned>,
                crate::bits::BitSlice<'y, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let quantizer = make_quantizer(dim);

            // Document rows continue after the query rows, so the two inputs differ.
            let query_mat =
                compress_mat::<NBITS>(&quantizer, &generate_input_mat(nq, dim, 0), nq, dim);
            let doc_mat =
                compress_mat::<NBITS>(&quantizer, &generate_input_mat(nd, dim, nq), nd, dim);

            let query: QueryMatRef<_> = query_mat.as_view().into();
            let doc = doc_mat.as_view();

            // MaxSim must agree with the brute-force reference.
            let mut scores = vec![0.0f32; nq];
            MaxSim::new(&mut scores).unwrap().evaluate(query, doc);

            for (i, (q, &got)) in query.rows().zip(scores.iter()).enumerate() {
                let exp = naive_max_sim_single(q, &doc);
                assert!(
                    (got - exp).abs() < 1e-5,
                    "NBITS={NBITS} ({nq},{nd},{dim}) MaxSim[{i}]: {got} != {exp}"
                );
            }

            // The raw kernel must reproduce MaxSim exactly.
            let mut kernel_scores = vec![0.0f32; nq];
            MinMaxKernel::max_sim_kernel(query, doc, |i, s| kernel_scores[i] = s).unwrap();
            assert_eq!(
                scores, kernel_scores,
                "NBITS={NBITS} ({nq},{nd},{dim}) kernel mismatch"
            );

            // Chamfer is the sum of the per-query MaxSim scores.
            let chamfer = Chamfer::evaluate(query, doc);
            let sum = scores.iter().copied().sum::<f32>();
            assert!(
                (chamfer - sum).abs() < 1e-4,
                "NBITS={NBITS} ({nq},{nd},{dim}) Chamfer {chamfer} != sum {sum}"
            );
        }
    }

    expand_to_bitrates!(matches_naive, test_matches_naive);

    #[test]
    #[should_panic(expected = "scores buffer not right size")]
    fn max_sim_panics_on_size_mismatch() {
        const DIM: usize = 4;
        let bytes_per_row = Data::<8>::canonical_bytes(DIM);
        let query_bytes = vec![0u8; 2 * bytes_per_row];
        let doc_bytes = vec![0u8; 3 * bytes_per_row];

        let query: QueryMatRef<_> = MatRef::new(MinMaxMeta::<8>::new(2, DIM), &query_bytes)
            .unwrap()
            .into();
        let doc = MatRef::new(MinMaxMeta::<8>::new(3, DIM), &doc_bytes).unwrap();

        // Two query vectors but five score slots: evaluate must reject this.
        let mut scores = vec![0.0f32; 5];
        MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
    }
}