diskann_quantization/multi_vector/distance/query_computer/
mod.rs1mod f16;
35mod f32;
36
37use crate::multi_vector::{BlockTransposed, MatRef, Standard};
38
39#[derive(Debug)]
41pub struct QueryComputer<T: Copy> {
42 inner: Box<dyn DynQueryComputer<T>>,
43}
44
45impl<T: Copy> QueryComputer<T> {
46 #[inline]
48 pub fn nrows(&self) -> usize {
49 self.inner.nrows()
50 }
51
52 pub fn chamfer(&self, doc: MatRef<'_, Standard<T>>) -> f32 {
56 let nq = self.nrows();
57 if doc.num_vectors() == 0 {
58 return 0.0;
59 }
60 let mut scores = vec![0.0f32; nq];
61 self.max_sim(doc, &mut scores);
62 scores.iter().sum()
63 }
64
65 pub fn max_sim(&self, doc: MatRef<'_, Standard<T>>, scores: &mut [f32]) {
74 let nq = self.nrows();
75 assert_eq!(
76 scores.len(),
77 nq,
78 "scores buffer not right size: {} != {}",
79 scores.len(),
80 nq
81 );
82
83 if doc.num_vectors() == 0 {
84 return;
85 }
86
87 self.inner.compute_max_sim(doc, scores);
88 }
89}
90
91trait DynQueryComputer<T: Copy>: std::fmt::Debug + Send + Sync {
92 fn compute_max_sim(&self, doc: MatRef<'_, Standard<T>>, scores: &mut [f32]);
93 fn nrows(&self) -> usize;
94}
95
/// Bundles an `arch` value with query data that has already been prepared.
///
/// NOTE(review): `arch` is presumably a CPU-architecture token used to select
/// a kernel; nothing in this file constrains `A`, so confirm against the
/// `f16`/`f32` submodules.
#[derive(Debug)]
struct Prepared<A, Q> {
    /// Value stored alongside the prepared query (see the review note above).
    arch: A,
    /// Query data in its prepared representation.
    prepared: Q,
}
101
102fn build_prepared<T: Copy + Default, A, const GROUP: usize>(
103 arch: A,
104 query: MatRef<'_, Standard<T>>,
105) -> Prepared<A, BlockTransposed<T, GROUP>> {
106 let prepared = BlockTransposed::<T, GROUP>::from_matrix_view(query.as_matrix_view());
107 Prepared { arch, prepared }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_vector::{Chamfer, MaxSim, QueryMatRef};
    use diskann_vector::distance::InnerProduct;
    use diskann_vector::{DistanceFunctionMut, PureDistanceFunction};

    /// Conversion from `f32`, so the same fixture generator can serve every
    /// element type under test.
    trait FromF32 {
        fn from_f32(v: f32) -> Self;
    }

    impl FromF32 for f32 {
        fn from_f32(v: f32) -> Self {
            v
        }
    }

    impl FromF32 for half::f16 {
        fn from_f32(v: f32) -> Self {
            diskann_wide::cast_f32_to_f16(v)
        }
    }

    /// Wraps `data` as an `nrows x ncols` matrix view, panicking on a shape mismatch.
    fn make_mat<T: Copy>(data: &[T], nrows: usize, ncols: usize) -> MatRef<'_, Standard<T>> {
        MatRef::new(Standard::new(nrows, ncols).unwrap(), data).unwrap()
    }

    /// Produces `len` values following the pattern `(index + shift) % ceil`.
    fn make_test_data<T: FromF32>(len: usize, ceil: usize, shift: usize) -> Vec<T> {
        (0..len)
            .map(|v| T::from_f32(((v + shift) % ceil) as f32))
            .collect()
    }

    /// Builds the `(query, doc)` backing buffers shared by both comparison helpers.
    fn make_case_data<T: FromF32>(nq: usize, nd: usize, dim: usize) -> (Vec<T>, Vec<T>) {
        (
            make_test_data::<T>(nq * dim, dim, dim / 2),
            make_test_data::<T>(nd * dim, dim, dim),
        )
    }

    /// `(query rows, doc rows, dimension)` triples exercised by the comparison tests.
    const TEST_CASES: &[(usize, usize, usize)] = &[
        (1, 1, 4),
        (5, 3, 5),
        (17, 4, 64),
        (16, 6, 32),
    ];

    /// Checks that `QueryComputer::chamfer` agrees with the reference
    /// `Chamfer::evaluate` within `tol` for every entry of `TEST_CASES`.
    fn check_chamfer_matches<T: Copy + FromF32>(
        build: fn(MatRef<'_, Standard<T>>) -> QueryComputer<T>,
        tol: f32,
        label: &str,
    ) where
        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let (query_data, doc_data) = make_case_data::<T>(nq, nd, dim);

            let query = make_mat(&query_data, nq, dim);
            let doc = make_mat(&doc_data, nd, dim);

            let expected = Chamfer::evaluate(QueryMatRef::from(query), doc);
            let actual = build(query).chamfer(doc);

            assert!(
                (actual - expected).abs() < tol,
                "{label}Chamfer mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}",
            );
        }
    }

    /// Checks that `QueryComputer::max_sim` agrees element-wise with the
    /// reference `MaxSim` evaluation within `tol` for every entry of `TEST_CASES`.
    fn check_max_sim_matches<T: Copy + FromF32>(
        build: fn(MatRef<'_, Standard<T>>) -> QueryComputer<T>,
        tol: f32,
        label: &str,
    ) where
        InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>,
    {
        for &(nq, nd, dim) in TEST_CASES {
            let (query_data, doc_data) = make_case_data::<T>(nq, nd, dim);

            let query = make_mat(&query_data, nq, dim);
            let doc = make_mat(&doc_data, nd, dim);

            let mut expected_scores = vec![0.0f32; nq];
            let _ = MaxSim::new(&mut expected_scores)
                .unwrap()
                .evaluate(QueryMatRef::from(query), doc);

            let computer = build(query);
            let mut actual_scores = vec![0.0f32; nq];
            computer.max_sim(doc, &mut actual_scores);

            // Both buffers hold exactly `nq` entries, so zipping drops nothing.
            for (i, (actual, expected)) in actual_scores.iter().zip(&expected_scores).enumerate() {
                assert!(
                    (actual - expected).abs() < tol,
                    "{label}MaxSim[{i}] mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}",
                );
            }
        }
    }

    #[test]
    fn query_computer_dimensions() {
        let data = vec![1.0f32; 5 * 8];
        let query = make_mat(&data, 5, 8);
        let computer = QueryComputer::<f32>::new(query);

        assert_eq!(computer.nrows(), 5);
    }

    #[test]
    fn query_computer_f16_dimensions() {
        let data = vec![diskann_wide::cast_f32_to_f16(1.0); 5 * 8];
        let query = make_mat(data.as_slice(), 5, 8);
        let computer = QueryComputer::<half::f16>::new(query);

        assert_eq!(computer.nrows(), 5);
    }

    #[test]
    fn chamfer_with_zero_docs() {
        let query = make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2);
        let computer = QueryComputer::<f32>::new(query);
        let doc = make_mat(&[], 0, 2);
        assert_eq!(computer.chamfer(doc), 0.0);
    }

    #[test]
    fn max_sim_with_zero_docs() {
        // A non-zero sentinel is required here: with a zero-initialised buffer
        // the test could not tell "left untouched" apart from "overwritten
        // with zeros".
        const SENTINEL: f32 = -123.0;

        let query = make_mat(&[1.0f32, 0.0, 0.0, 1.0], 2, 2);
        let computer = QueryComputer::<f32>::new(query);
        let doc = make_mat::<f32>(&[], 0, 2);
        let mut scores = vec![SENTINEL; 2];
        computer.max_sim(doc, &mut scores);
        for &s in &scores {
            assert_eq!(s, SENTINEL, "zero-doc MaxSim should leave scores untouched");
        }
    }

    #[test]
    #[should_panic(expected = "scores buffer not right size")]
    fn max_sim_panics_on_size_mismatch() {
        let query = make_mat(&[1.0f32, 2.0, 3.0, 4.0], 2, 2);
        let computer = QueryComputer::<f32>::new(query);
        let doc = make_mat(&[1.0, 1.0], 1, 2);
        // Three slots for a two-row query: the length check must fire.
        let mut scores = vec![0.0f32; 3];
        computer.max_sim(doc, &mut scores);
    }

    /// Generates a submodule holding the two reference-comparison tests for one element type.
    macro_rules! test_matches_fallback {
        ($mod_name:ident, $ty:ty, $tol:expr, $label:literal) => {
            mod $mod_name {
                use super::*;

                #[test]
                fn chamfer_matches_fallback() {
                    check_chamfer_matches(QueryComputer::<$ty>::new, $tol, $label);
                }

                #[test]
                fn max_sim_matches_fallback() {
                    check_max_sim_matches(QueryComputer::<$ty>::new, $tol, $label);
                }
            }
        };
    }

    test_matches_fallback!(f32, f32, 1e-10, "f32 ");
    test_matches_fallback!(f16, half::f16, 1e-10, "f16 ");
}