Skip to main content

diskann_quantization/multi_vector/distance/
simple.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! Simple kernel implementation of multi-vector distance computation.
5
6use std::ops::Deref;
7
8use diskann_vector::distance::InnerProduct;
9use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};
10
11use super::max_sim::{Chamfer, MaxSim};
12use crate::multi_vector::{MatRef, MaxSimError, Repr, Standard};
13
14/////////////////
15// QueryMatRef //
16/////////////////
17
18/// A query matrix view for asymmetric distance functions.
19///
20/// This wrapper distinguishes query matrices from document matrices
21/// at compile time, preventing accidental argument swapping in asymmetric
22/// distance computations like [`MaxSim`] and [`Chamfer`].
23///
24/// # Example
25///
26/// ```
27/// use diskann_quantization::multi_vector::{MatRef, Standard};
28/// use diskann_quantization::multi_vector::distance::QueryMatRef;
29///
30/// let data = [1.0f32, 2.0, 3.0, 4.0];
31/// let view = MatRef::new(Standard::new(2, 2).unwrap(), &data).unwrap();
32/// let query: QueryMatRef<_> = view.into();
33/// ```
34#[derive(Debug, Clone, Copy)]
35pub struct QueryMatRef<'a, T: Repr>(pub MatRef<'a, T>);
36
37impl<'a, T: Repr> From<MatRef<'a, T>> for QueryMatRef<'a, T> {
38    fn from(view: MatRef<'a, T>) -> Self {
39        Self(view)
40    }
41}
42
43/// Deref so that we can transparently access the `MatRef` in distance functions.
44impl<'a, T: Repr> Deref for QueryMatRef<'a, T> {
45    type Target = MatRef<'a, T>;
46
47    fn deref(&self) -> &Self::Target {
48        &self.0
49    }
50}
51
52//////////////////
53// SimpleKernel //
54//////////////////
55
56/// Simple double-loop kernel to compute max-sim distances over multi-vectors.
57///
58/// This kernel performs a simple double-loop over the rows of `query`
59/// and the `doc` and dispatches to [`InnerProduct`] to compute the similarity.
60pub struct SimpleKernel;
61
62impl SimpleKernel {
63    /// Core kernel for computing per-query-vector max similarities (min negated inner-product).
64    ///
65    /// For each `query` vector, computes the maximum similarity (negated inner product)
66    /// to any document vector, then calls `f(index, score)` with the result. If
67    /// there are no vectors in the `doc`, the kernel returns immediately.
68    ///
69    /// The callback can be used to aggregate or set scores as needed - as is the
70    /// case with [`MaxSim`] and [`Chamfer`].
71    ///
72    /// # Arguments
73    ///
74    /// * `query` - The query multi-vector (wrapped as [`QueryMatRef`])
75    /// * `doc` - The document multi-vector
76    /// * `f` - Callback invoked with `(query_index, similarity)` for each query vector
77    #[inline]
78    pub(crate) fn max_sim_kernel<F, T: Copy>(
79        query: QueryMatRef<'_, Standard<T>>,
80        doc: MatRef<'_, Standard<T>>,
81        mut f: F,
82    ) where
83        F: FnMut(usize, f32),
84        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
85    {
86        // Early exit if no doc vectors - callback should never be invoked
87        if doc.num_vectors() == 0 {
88            return;
89        }
90
91        for (i, q_vec) in query.rows().enumerate() {
92            // `InnerProduct::evaluate` returns negated inner product
93            let mut min_dist = f32::MAX;
94
95            for d_vec in doc.rows() {
96                let dist = InnerProduct::evaluate(q_vec, d_vec);
97                min_dist = min_dist.min(dist);
98            }
99
100            f(i, min_dist);
101        }
102    }
103}
104
105////////////
106// MaxSim //
107////////////
108
109impl<T: Copy>
110    DistanceFunctionMut<
111        QueryMatRef<'_, Standard<T>>,
112        MatRef<'_, Standard<T>>,
113        Result<(), MaxSimError>,
114    > for MaxSim<'_>
115where
116    InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
117{
118    #[inline(always)]
119    fn evaluate(
120        &mut self,
121        query: QueryMatRef<'_, Standard<T>>,
122        doc: MatRef<'_, Standard<T>>,
123    ) -> Result<(), MaxSimError> {
124        let size = self.size();
125        let n_queries = query.num_vectors();
126
127        if self.size() != query.num_vectors() {
128            return Err(MaxSimError::InvalidBufferLength(size, n_queries));
129        }
130
131        SimpleKernel::max_sim_kernel(query, doc, |i, score| {
132            // SAFETY: We asserted that self.size() == query.num_vectors(),
133            // and i < query.num_vectors() due to the kernel loop bound.
134            unsafe { *self.scores.get_unchecked_mut(i) = score };
135        });
136
137        Ok(())
138    }
139}
140
141/////////////
142// Chamfer //
143/////////////
144
145impl<T: Copy> PureDistanceFunction<QueryMatRef<'_, Standard<T>>, MatRef<'_, Standard<T>>, f32>
146    for Chamfer
147where
148    InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
149{
150    #[inline(always)]
151    fn evaluate(query: QueryMatRef<'_, Standard<T>>, doc: MatRef<'_, Standard<T>>) -> f32 {
152        let mut sum = 0.0f32;
153
154        SimpleKernel::max_sim_kernel(query, doc, |_i, score| {
155            sum += score;
156        });
157
158        sum
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to create a QueryMatRef from raw data
    fn make_query(data: &[f32], nrows: usize, ncols: usize) -> QueryMatRef<'_, Standard<f32>> {
        MatRef::new(Standard::new(nrows, ncols).unwrap(), data)
            .unwrap()
            .into()
    }

    /// Helper to create a MatRef from raw data
    fn make_doc(data: &[f32], nrows: usize, ncols: usize) -> MatRef<'_, Standard<f32>> {
        MatRef::new(Standard::new(nrows, ncols).unwrap(), data).unwrap()
    }

    /// Naive implementation of max-sim for a single query vector against all doc vectors.
    fn naive_max_sim_single(query_vec: &[f32], doc: &MatRef<'_, Standard<f32>>) -> f32 {
        doc.rows()
            .map(|d_vec| {
                let ip: f32 = query_vec.iter().zip(d_vec.iter()).map(|(a, b)| a * b).sum();
                -ip
            })
            .fold(f32::MAX, f32::min)
    }

    /// Generate `len` deterministic test values.
    ///
    /// Element `v` is `(v + shift) % ceil`, so every value is a whole number in
    /// `[0, ceil)`. Different `shift`s give distinct query/doc data.
    fn make_test_data(len: usize, ceil: usize, shift: usize) -> Vec<f32> {
        (0..len).map(|v| ((v + shift) % ceil) as f32).collect()
    }

    mod query_mat_ref {
        use super::*;

        #[test]
        fn from_mat_ref_and_deref() {
            let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
            let view = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap();
            let query: QueryMatRef<_> = view.into();

            // Deref access works
            assert_eq!(query.num_vectors(), 2);
            assert_eq!(query.vector_dim(), 3);
            assert_eq!(query.get_row(0), Some(&[1.0f32, 2.0, 3.0][..]));
        }

        #[test]
        fn is_copy() {
            let data = [1.0f32, 2.0];
            let query = make_query(&data, 1, 2);
            let copy = query;
            let _ = (query, copy); // Both usable
        }
    }

    mod distance_functions {
        use diskann_utils::Reborrow;

        use super::*;

        /// A score buffer whose length differs from the number of query vectors
        /// is reported as an error (not a panic), and no scores are written.
        #[test]
        fn max_sim_errors_on_size_mismatch() {
            let query = make_query(&[1.0, 2.0, 3.0, 4.0], 2, 2);
            let doc = make_doc(&[1.0, 1.0], 1, 2);

            let mut scores = vec![0.0f32; 3]; // Wrong size: 3 slots for 2 query vectors
            let r = MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
            assert!(matches!(r, Err(MaxSimError::InvalidBufferLength(..))));
            assert!(scores.iter().all(|&s| s == 0.0));
        }

        /// Tests both MaxSim and Chamfer against naive implementations across
        /// various matrix sizes including edge cases (single row/col).
        #[test]
        fn matches_naive_implementation() {
            let test_cases = [
                (1, 1, 4),   // Single query, single doc
                (1, 5, 8),   // Single query, multiple docs
                (5, 1, 8),   // Multiple queries, single doc
                (3, 4, 16),  // General case
                (7, 7, 32),  // Square case
                (2, 3, 128), // Larger dimension
            ];

            for &(nq, nd, dim) in &test_cases {
                let query_data = make_test_data(nq * dim, dim, dim / 2);
                let doc_data = make_test_data(nd * dim, dim, dim);

                let query = make_query(&query_data, nq, dim);
                let doc = make_doc(&doc_data, nd, dim);

                // Test MaxSim
                let mut scores = vec![0.0f32; nq];
                let r = MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
                assert!(r.is_ok());

                let expected_scores: Vec<f32> = query
                    .rows()
                    .map(|q_vec| naive_max_sim_single(q_vec, &doc))
                    .collect();

                assert_eq!(scores.len(), expected_scores.len());
                for (i, (got, want)) in scores.iter().zip(&expected_scores).enumerate() {
                    assert!(
                        (got - want).abs() < 1e-5,
                        "MaxSim mismatch at {} for ({},{},{})",
                        i,
                        nq,
                        nd,
                        dim
                    );
                }

                // Check that SimpleKernel is also correct.
                SimpleKernel::max_sim_kernel(query, doc, |i, score| {
                    assert!((scores[i] - score).abs() <= 1e-6)
                });

                // Test Chamfer
                let chamfer = Chamfer::evaluate(query, doc);
                let expected_chamfer: f32 = expected_scores.iter().sum();

                assert!(
                    (chamfer - expected_chamfer).abs() < 1e-4,
                    "Chamfer mismatch for ({},{},{})",
                    nq,
                    nd,
                    dim
                );
            }
        }

        /// With no doc vectors the kernel never calls back. As a result, MaxSim
        /// leaves the buffer untouched and Chamfer sums zero terms.
        #[test]
        fn empty_doc_leaves_outputs_untouched() {
            let query = make_query(&[1.0, 2.0, 3.0, 4.0], 2, 2);
            let doc = make_doc(&[], 0, 2);

            SimpleKernel::max_sim_kernel(query, doc, |i, score| {
                panic!("unexpected callback ({}, {}) for an empty doc", i, score)
            });

            let mut scores = vec![-7.0f32; 2];
            let r = MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
            assert!(r.is_ok());
            assert_eq!(scores, [-7.0f32, -7.0]);

            assert_eq!(Chamfer::evaluate(query, doc), 0.0);
        }

        #[test]
        fn chamfer_with_zero_queries_returns_zero() {
            let query = make_query(&[], 0, 2);
            let doc = make_doc(&[1.0, 0.0, 0.0, 1.0], 2, 2);

            let result = Chamfer::evaluate(query, doc);

            // No query vectors means sum is 0
            assert_eq!(result, 0.0);

            let result = Chamfer::evaluate(doc.into(), query.deref().reborrow());

            assert_eq!(result, 0.0);
        }
    }
}