diskann_quantization/multi_vector/distance/
simple.rs1use std::ops::Deref;
7
8use diskann_vector::distance::InnerProduct;
9use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};
10
11use super::max_sim::{Chamfer, MaxSim};
12use crate::multi_vector::{MatRef, MaxSimError, Repr, Standard};
13
14#[derive(Debug, Clone, Copy)]
35pub struct QueryMatRef<'a, T: Repr>(pub MatRef<'a, T>);
36
37impl<'a, T: Repr> From<MatRef<'a, T>> for QueryMatRef<'a, T> {
38 fn from(view: MatRef<'a, T>) -> Self {
39 Self(view)
40 }
41}
42
43impl<'a, T: Repr> Deref for QueryMatRef<'a, T> {
45 type Target = MatRef<'a, T>;
46
47 fn deref(&self) -> &Self::Target {
48 &self.0
49 }
50}
51
52pub struct SimpleKernel;
61
62impl SimpleKernel {
63 #[inline]
78 pub(crate) fn max_sim_kernel<F, T: Copy>(
79 query: QueryMatRef<'_, Standard<T>>,
80 doc: MatRef<'_, Standard<T>>,
81 mut f: F,
82 ) where
83 F: FnMut(usize, f32),
84 InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
85 {
86 if doc.num_vectors() == 0 {
88 return;
89 }
90
91 for (i, q_vec) in query.rows().enumerate() {
92 let mut min_dist = f32::MAX;
94
95 for d_vec in doc.rows() {
96 let dist = InnerProduct::evaluate(q_vec, d_vec);
97 min_dist = min_dist.min(dist);
98 }
99
100 f(i, min_dist);
101 }
102 }
103}
104
105impl<T: Copy>
110 DistanceFunctionMut<
111 QueryMatRef<'_, Standard<T>>,
112 MatRef<'_, Standard<T>>,
113 Result<(), MaxSimError>,
114 > for MaxSim<'_>
115where
116 InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
117{
118 #[inline(always)]
119 fn evaluate(
120 &mut self,
121 query: QueryMatRef<'_, Standard<T>>,
122 doc: MatRef<'_, Standard<T>>,
123 ) -> Result<(), MaxSimError> {
124 let size = self.size();
125 let n_queries = query.num_vectors();
126
127 if self.size() != query.num_vectors() {
128 return Err(MaxSimError::InvalidBufferLength(size, n_queries));
129 }
130
131 SimpleKernel::max_sim_kernel(query, doc, |i, score| {
132 unsafe { *self.scores.get_unchecked_mut(i) = score };
135 });
136
137 Ok(())
138 }
139}
140
141impl<T: Copy> PureDistanceFunction<QueryMatRef<'_, Standard<T>>, MatRef<'_, Standard<T>>, f32>
146 for Chamfer
147where
148 InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
149{
150 #[inline(always)]
151 fn evaluate(query: QueryMatRef<'_, Standard<T>>, doc: MatRef<'_, Standard<T>>) -> f32 {
152 let mut sum = 0.0f32;
153
154 SimpleKernel::max_sim_kernel(query, doc, |_i, score| {
155 sum += score;
156 });
157
158 sum
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Views `data` as an `nrows` x `ncols` query matrix.
    ///
    /// Panics if the layout or the view cannot be constructed from the
    /// arguments.
    fn make_query(data: &[f32], nrows: usize, ncols: usize) -> QueryMatRef<'_, Standard<f32>> {
        QueryMatRef::from(make_doc(data, nrows, ncols))
    }

    /// Views `data` as an `nrows` x `ncols` document matrix.
    ///
    /// Panics if the layout or the view cannot be constructed from the
    /// arguments.
    fn make_doc(data: &[f32], nrows: usize, ncols: usize) -> MatRef<'_, Standard<f32>> {
        let layout = Standard::new(nrows, ncols).unwrap();
        MatRef::new(layout, data).unwrap()
    }

    /// Reference implementation for one query vector: the minimum negated
    /// dot product against every row of `doc` (`f32::MAX` for an empty `doc`).
    fn naive_max_sim_single(query_vec: &[f32], doc: &MatRef<'_, Standard<f32>>) -> f32 {
        let mut best = f32::MAX;
        for d_vec in doc.rows() {
            let ip: f32 = query_vec.iter().zip(d_vec.iter()).map(|(a, b)| a * b).sum();
            best = best.min(-ip);
        }
        best
    }

    /// Produces `len` deterministic values cycling through `0..ceil`,
    /// starting at offset `shift`.
    fn make_test_data(len: usize, ceil: usize, shift: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(len);
        for v in 0..len {
            values.push(((v + shift) % ceil) as f32);
        }
        values
    }

    mod query_mat_ref {
        use super::*;

        /// Conversion from `MatRef` keeps the shape and the row contents
        /// reachable through `Deref`.
        #[test]
        fn from_mat_ref_and_deref() {
            let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
            let view = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap();
            let query = QueryMatRef::from(view);

            // Every accessor below is resolved on the wrapped view.
            assert_eq!(query.num_vectors(), 2);
            assert_eq!(query.vector_dim(), 3);

            let first_row: &[f32] = &[1.0, 2.0, 3.0];
            assert_eq!(query.get_row(0), Some(first_row));
        }

        /// Compiles only if `QueryMatRef` is `Copy`: the original binding is
        /// used again after being assigned to another one.
        #[test]
        fn is_copy() {
            let data = [1.0f32, 2.0];
            let query = make_query(&data, 1, 2);
            let duplicate = query;
            let _ = (query, duplicate);
        }
    }

    mod distance_functions {
        use diskann_utils::Reborrow;

        use super::*;

        /// A score buffer whose length differs from the query row count is
        /// rejected. Despite the test name, the rejection is an `Err`, not a
        /// panic.
        #[test]
        fn max_sim_panics_on_size_mismatch() {
            let query = make_query(&[1.0, 2.0, 3.0, 4.0], 2, 2);
            let doc = make_doc(&[1.0, 1.0], 1, 2);

            // Three slots for two query rows.
            let mut scores = vec![0.0f32; 3];
            let outcome = MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
            assert!(outcome.is_err());
        }

        /// `MaxSim`, the raw kernel and `Chamfer` all agree with the naive
        /// reference over a range of shapes.
        #[test]
        fn matches_naive_implementation() {
            // (query rows, document rows, dimension)
            let test_cases = [
                (1, 1, 4),
                (1, 5, 8),
                (5, 1, 8),
                (3, 4, 16),
                (7, 7, 32),
                (2, 3, 128),
            ];

            for &(nq, nd, dim) in &test_cases {
                let query_data = make_test_data(nq * dim, dim, dim / 2);
                let doc_data = make_test_data(nd * dim, dim, dim);

                let query = make_query(&query_data, nq, dim);
                let doc = make_doc(&doc_data, nd, dim);

                let mut scores = vec![0.0f32; nq];
                let outcome = MaxSim::new(&mut scores).unwrap().evaluate(query, doc);
                assert!(outcome.is_ok());

                let mut expected_scores: Vec<f32> = Vec::new();
                for q_vec in query.rows() {
                    expected_scores.push(naive_max_sim_single(q_vec, &doc));
                }

                for i in 0..nq {
                    let delta = (scores[i] - expected_scores[i]).abs();
                    assert!(
                        delta < 1e-5,
                        "MaxSim mismatch at {} for ({},{},{})",
                        i,
                        nq,
                        nd,
                        dim
                    );
                }

                // The kernel called directly must reproduce what MaxSim stored.
                SimpleKernel::max_sim_kernel(query, doc, |i, score| {
                    assert!((scores[i] - score).abs() <= 1e-6)
                });

                // Chamfer is the sum of the per-query scores.
                let chamfer = Chamfer::evaluate(query, doc);
                let expected_chamfer: f32 = expected_scores.iter().sum();

                assert!(
                    (chamfer - expected_chamfer).abs() < 1e-4,
                    "Chamfer mismatch for ({},{},{})",
                    nq,
                    nd,
                    dim
                );
            }
        }

        /// An empty operand on either side yields a Chamfer value of zero.
        #[test]
        fn chamfer_with_zero_queries_returns_zero() {
            let query = make_query(&[], 0, 2);
            let doc = make_doc(&[1.0, 0.0, 0.0, 1.0], 2, 2);

            // No query rows: nothing is accumulated.
            let no_queries = Chamfer::evaluate(query, doc);
            assert_eq!(no_queries, 0.0);

            // Roles swapped: the empty matrix is now the document side.
            let swapped_query = QueryMatRef::from(doc);
            let swapped_doc = query.deref().reborrow();
            let no_docs = Chamfer::evaluate(swapped_query, swapped_doc);
            assert_eq!(no_docs, 0.0);
        }
    }
}