1use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};
7
8use super::super::vectors::{DataRef, MinMaxIP};
9use super::meta::MinMaxMeta;
10use crate::bits::{Representation, Unsigned};
11use crate::distances::{self, UnequalLengths};
12use crate::multi_vector::distance::QueryMatRef;
13use crate::multi_vector::{Chamfer, MatRef, MaxSim};
14
/// Namespace type for the MaxSim kernel over MinMax-quantized multi-vectors.
///
/// It carries no state; its associated function `max_sim_kernel` is the shared
/// inner loop used by the `MaxSim` and `Chamfer` implementations in this module.
pub struct MinMaxKernel;
24
25impl MinMaxKernel {
26 #[inline(always)]
37 pub(crate) fn max_sim_kernel<const NBITS: usize, F>(
38 query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
39 doc: MatRef<'_, MinMaxMeta<NBITS>>,
40 mut f: F,
41 ) -> Result<(), UnequalLengths>
42 where
43 Unsigned: Representation<NBITS>,
44 distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
45 crate::bits::BitSlice<'x, NBITS, Unsigned>,
46 crate::bits::BitSlice<'y, NBITS, Unsigned>,
47 distances::MathematicalResult<u32>,
48 >,
49 F: FnMut(usize, f32),
50 {
51 for (i, q_ref) in query.rows().enumerate() {
52 let mut min_distance = f32::MAX;
54
55 for d_ref in doc.rows() {
56 let dist = <MinMaxIP as PureDistanceFunction<
58 DataRef<'_, NBITS>,
59 DataRef<'_, NBITS>,
60 distances::Result<f32>,
61 >>::evaluate(q_ref, d_ref)?;
62
63 min_distance = min_distance.min(dist);
64 }
65
66 f(i, min_distance);
67 }
68
69 Ok(())
70 }
71}
72
73impl<const NBITS: usize>
78 DistanceFunctionMut<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<NBITS>>>
79 for MaxSim<'_>
80where
81 Unsigned: Representation<NBITS>,
82 distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
83 crate::bits::BitSlice<'x, NBITS, Unsigned>,
84 crate::bits::BitSlice<'y, NBITS, Unsigned>,
85 distances::MathematicalResult<u32>,
86 >,
87{
88 #[inline(always)]
89 fn evaluate(
90 &mut self,
91 query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
92 doc: MatRef<'_, MinMaxMeta<NBITS>>,
93 ) {
94 assert!(
95 self.size() == query.num_vectors(),
96 "scores buffer not right size : {} != {}",
97 self.size(),
98 query.num_vectors()
99 );
100
101 let _ = MinMaxKernel::max_sim_kernel(query, doc, |i, score| {
102 let _ = self.set(i, score);
105 });
106 }
107}
108
109impl<const NBITS: usize>
114 PureDistanceFunction<QueryMatRef<'_, MinMaxMeta<NBITS>>, MatRef<'_, MinMaxMeta<NBITS>>, f32>
115 for Chamfer
116where
117 Unsigned: Representation<NBITS>,
118 distances::InnerProduct: for<'a, 'b> PureDistanceFunction<
119 crate::bits::BitSlice<'a, NBITS, Unsigned>,
120 crate::bits::BitSlice<'b, NBITS, Unsigned>,
121 distances::MathematicalResult<u32>,
122 >,
123{
124 #[inline(always)]
125 fn evaluate(
126 query: QueryMatRef<'_, MinMaxMeta<NBITS>>,
127 doc: MatRef<'_, MinMaxMeta<NBITS>>,
128 ) -> f32 {
129 let mut sum = 0.0f32;
130
131 let _ = MinMaxKernel::max_sim_kernel(query, doc, |_i, score| {
132 sum += score;
133 });
134
135 sum
136 }
137}
138
#[cfg(test)]
mod tests {
    // Checks the MinMax `MaxSim` / `Chamfer` implementations against a naive reference
    // built directly on `MinMaxIP`, across several bit widths and matrix shapes.
    use super::*;
    use crate::CompressInto;
    use crate::algorithms::Transform;
    use crate::algorithms::transforms::NullTransform;
    use crate::bits::{Representation, Unsigned};
    use crate::minmax::{Data, MinMaxQuantizer};
    use crate::multi_vector::{Defaulted, Mat, Standard};
    use crate::num::Positive;
    use diskann_utils::ReborrowMut;
    use std::num::NonZeroUsize;

    // Generates a `#[test] fn $name()` that runs the generic `$func` once for each
    // supported bit width (1, 2, 4 and 8 bits).
    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }

    /// Shapes exercised by `test_matches_naive`, as `(nq, nd, dim)`: number of query
    /// rows, number of document rows, and the per-row dimension.
    const TEST_CASES: &[(usize, usize, usize)] = &[
        (1, 1, 4),
        (1, 5, 8),
        (5, 1, 8),
        (3, 4, 16),
        (7, 7, 32),
        (2, 3, 128),
    ];

    /// Builds a `MinMaxQuantizer` over a null (pass-through-named) transform of width
    /// `dim`. The second argument `1.0` is whatever parameter `MinMaxQuantizer::new`
    /// expects there; its meaning is not visible in this module.
    fn make_quantizer(dim: usize) -> MinMaxQuantizer {
        MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        )
    }

    /// Produces a row-major `n x dim` buffer where element `(i, j)` is
    /// `((i + offset) * dim + j) * 0.1`.
    ///
    /// `offset` shifts the row index so that query and document data generated with
    /// different offsets do not repeat each other.
    fn generate_input_mat(n: usize, dim: usize, offset: usize) -> Vec<f32> {
        (0..n * dim)
            .map(|idx| {
                let i = idx / dim;
                let j = idx % dim;
                ((i + offset) * dim + j) as f32 * 0.1
            })
            .collect()
    }

    /// Compresses a row-major `n x dim` f32 buffer into a freshly allocated
    /// `Mat<MinMaxMeta<NBITS>>` using `quantizer`. Panics (via `unwrap`) on any failure.
    fn compress_mat<const NBITS: usize>(
        quantizer: &MinMaxQuantizer,
        input: &[f32],
        n: usize,
        dim: usize,
    ) -> Mat<MinMaxMeta<NBITS>>
    where
        Unsigned: Representation<NBITS>,
    {
        let input_mat = MatRef::new(Standard::<f32>::new(n, dim).unwrap(), input).unwrap();
        let mut output: Mat<MinMaxMeta<NBITS>> =
            Mat::new(MinMaxMeta::new(n, dim), Defaulted).unwrap();
        quantizer
            .compress_into(input_mat, output.reborrow_mut())
            .unwrap();
        output
    }

    /// Reference MaxSim score for a single query row: the minimum `MinMaxIP` value
    /// against every row of `doc`, starting from `f32::MAX` (so an empty `doc` yields
    /// `f32::MAX`). Panics on any evaluation error.
    fn naive_max_sim_single<const NBITS: usize>(
        query: DataRef<'_, NBITS>,
        doc: &MatRef<'_, MinMaxMeta<NBITS>>,
    ) -> f32
    where
        Unsigned: Representation<NBITS>,
        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
                crate::bits::BitSlice<'x, NBITS, Unsigned>,
                crate::bits::BitSlice<'y, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
    {
        doc.rows()
            .map(|d| {
                <MinMaxIP as PureDistanceFunction<
                    DataRef<'_, NBITS>,
                    DataRef<'_, NBITS>,
                    distances::Result<f32>,
                >>::evaluate(query, d)
                .unwrap()
            })
            .fold(f32::MAX, f32::min)
    }

    /// For every shape in `TEST_CASES`, checks that:
    /// 1. `MaxSim` matches the naive per-row reference within `1e-5`;
    /// 2. the raw `MinMaxKernel::max_sim_kernel` produces exactly the `MaxSim` scores;
    /// 3. `Chamfer` equals the sum of the `MaxSim` scores within `1e-4`.
    fn test_matches_naive<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        distances::InnerProduct: for<'x, 'y> PureDistanceFunction<
                crate::bits::BitSlice<'x, NBITS, Unsigned>,
                crate::bits::BitSlice<'y, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let quantizer = make_quantizer(dim);

            // Document rows continue where the query rows stop (offset `nq`).
            let query_data = generate_input_mat(nq, dim, 0);
            let doc_data = generate_input_mat(nd, dim, nq);

            let query_mat = compress_mat::<NBITS>(&quantizer, &query_data, nq, dim);
            let doc_mat = compress_mat::<NBITS>(&quantizer, &doc_data, nd, dim);

            let query: QueryMatRef<_> = query_mat.as_view().into();
            let doc = doc_mat.as_view();

            let expected: Vec<f32> = query
                .rows()
                .map(|q| naive_max_sim_single(q, &doc))
                .collect();

            let mut scores = vec![0.0f32; nq];
            MaxSim::new(&mut scores).unwrap().evaluate(query, doc);

            for (i, (&got, &exp)) in scores.iter().zip(expected.iter()).enumerate() {
                assert!(
                    (got - exp).abs() < 1e-5,
                    "NBITS={NBITS} ({nq},{nd},{dim}) MaxSim[{i}]: {got} != {exp}"
                );
            }

            // The kernel is what `MaxSim` delegates to, so results must be bit-identical.
            let mut kernel_scores = vec![0.0f32; nq];
            MinMaxKernel::max_sim_kernel(query, doc, |i, s| kernel_scores[i] = s).unwrap();
            assert_eq!(
                scores, kernel_scores,
                "NBITS={NBITS} ({nq},{nd},{dim}) kernel mismatch"
            );

            // Chamfer is the sum of per-row MaxSim scores; a looser tolerance allows for
            // any difference in float summation.
            let chamfer = Chamfer::evaluate(query, doc);
            let sum: f32 = scores.iter().sum();
            assert!(
                (chamfer - sum).abs() < 1e-4,
                "NBITS={NBITS} ({nq},{nd},{dim}) Chamfer {chamfer} != sum {sum}"
            );
        }
    }

    expand_to_bitrates!(matches_naive, test_matches_naive);

    /// `MaxSim::evaluate` must panic when the score buffer length (5) differs from the
    /// number of query rows (2). The matrix contents are all-zero bytes because only
    /// the shapes matter here.
    #[test]
    #[should_panic(expected = "scores buffer not right size")]
    fn max_sim_panics_on_size_mismatch() {
        let dim = 4;
        // Presumably the byte size of one 8-bit MinMax row of width `dim` — confirm.
        let row_bytes = Data::<8>::canonical_bytes(dim);
        let query_data = vec![0u8; 2 * row_bytes];
        let doc_data = vec![0u8; 3 * row_bytes];

        let query: QueryMatRef<_> = MatRef::new(MinMaxMeta::<8>::new(2, dim), &query_data)
            .unwrap()
            .into();
        let doc = MatRef::new(MinMaxMeta::<8>::new(3, dim), &doc_data).unwrap();

        let mut scores = vec![0.0f32; 5];
        MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
    }
}